From ccd500cea1fa08f65c680d67b02d444c29422981 Mon Sep 17 00:00:00 2001 From: eaplatanios Date: Tue, 30 Jul 2024 16:14:37 -0700 Subject: [PATCH] Added support for compiling the CUDA stubs on Windows. --- xla/tsl/cuda/BUILD.bazel | 97 ++++++++++++++++++++--------------- xla/tsl/cuda/cublasLt_stub.cc | 4 ++ xla/tsl/cuda/cublas_stub.cc | 4 ++ xla/tsl/cuda/cuda_stub.cc | 4 ++ xla/tsl/cuda/cudart_stub.cc | 4 ++ xla/tsl/cuda/cudnn_stub.cc | 5 ++ xla/tsl/cuda/cufft_stub.cc | 4 ++ xla/tsl/cuda/cupti_stub.cc | 4 ++ xla/tsl/cuda/cusolver_stub.cc | 4 ++ xla/tsl/cuda/cusparse_stub.cc | 4 ++ xla/tsl/cuda/nccl_stub.cc | 4 ++ xla/tsl/cuda/stub.bzl | 1 + 12 files changed, 98 insertions(+), 41 deletions(-) diff --git a/xla/tsl/cuda/BUILD.bazel b/xla/tsl/cuda/BUILD.bazel index 5c375cace438a..469706ddf42d5 100644 --- a/xla/tsl/cuda/BUILD.bazel +++ b/xla/tsl/cuda/BUILD.bazel @@ -1,6 +1,7 @@ # Description: # Stubs for dynamically loading CUDA. +load("@bazel_skylib//lib:selects.bzl", "selects") load( "@tsl//tsl/platform:rules_cc.bzl", "cc_library", @@ -16,6 +17,14 @@ package( licenses = ["notice"], ) +selects.config_setting_group( + name = "linux_with_cuda_enabled", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@platforms//os:linux", + ], +) + cuda_stub( name = "cublas", srcs = ["cublas.symbols"], @@ -23,10 +32,11 @@ cuda_stub( cc_library( name = "cublas", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cublas_stub.cc", - "cublas.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cublas_stub.cc", "cublas.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cublas_stub.cc"], + "//conditions:default": [], + }), linkopts = if_cuda_is_configured(cuda_rpath_flags( "nvidia/cublas/lib", )), @@ -51,10 +61,11 @@ cuda_stub( cc_library( name = "cublas_lt", - srcs = if_cuda_is_configured([ - "cublasLt_stub.cc", - "cublasLt.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cublasLt_stub.cc", "cublasLt.tramp.S"], 
"@local_config_cuda//:is_cuda_enabled": ["cublasLt_stub.cc"], + "//conditions:default": [], + }), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", ], @@ -75,10 +86,11 @@ cuda_stub( cc_library( name = "cuda", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cuda_stub.cc", - "cuda.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cuda_stub.cc", "cuda.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cuda_stub.cc"], + "//conditions:default": [], + }), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", ], @@ -100,11 +112,8 @@ cuda_stub( cc_library( name = "cudart", # buildifier: disable=duplicated-name srcs = select({ - # include dynamic loading implementation only when if_cuda_is_configured and build dynamically - "@xla//xla/tsl:is_cuda_enabled_and_oss": [ - "cudart.tramp.S", - "cudart_stub.cc", - ], + ":linux_with_cuda_enabled": ["cudart_stub.cc", "cudart.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cudart_stub.cc"], "//conditions:default": [], }), linkopts = select({ @@ -136,10 +145,11 @@ cuda_stub( cc_library( name = "cudnn", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cudnn_stub.cc", - "cudnn.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cudnn_stub.cc", "cudnn.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cudnn_stub.cc"], + "//conditions:default": [], + }), linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cudnn/lib")), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", @@ -174,10 +184,11 @@ cuda_stub( cc_library( name = "cufft", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cufft_stub.cc", - "cufft.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cufft_stub.cc", "cufft.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cufft_stub.cc"], + "//conditions:default": [], + }), linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cufft/lib")), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", @@ -199,10 +210,11 @@ cuda_stub( 
cc_library( name = "cupti", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cupti_stub.cc", - "cupti.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cupti_stub.cc", "cupti.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cupti_stub.cc"], + "//conditions:default": [], + }), data = if_cuda_is_configured(["@local_config_cuda//cuda:cupti_dsos"]), linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cuda_cupti/lib")), local_defines = [ @@ -226,10 +238,11 @@ cuda_stub( cc_library( name = "cusolver", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cusolver_stub.cc", - "cusolver.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cusolver_stub.cc", "cusolver.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cusolver_stub.cc"], + "//conditions:default": [], + }), linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cusolver/lib")), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", @@ -251,10 +264,11 @@ cuda_stub( cc_library( name = "cusparse", # buildifier: disable=duplicated-name - srcs = if_cuda_is_configured([ - "cusparse_stub.cc", - "cusparse.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["cusparse_stub.cc", "cusparse.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["cusparse_stub.cc"], + "//conditions:default": [], + }), linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cusparse/lib")), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", @@ -277,10 +291,11 @@ cuda_stub( cc_library( name = "nccl_stub", - srcs = if_cuda_is_configured([ - "nccl_stub.cc", - "nccl.tramp.S", - ]), + srcs = select({ + ":linux_with_cuda_enabled": ["nccl_stub.cc", "nccl.tramp.S"], + "@local_config_cuda//:is_cuda_enabled": ["nccl_stub.cc"], + "//conditions:default": [], + }), linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), local_defines = [ "IMPLIB_EXPORT_SHIMS=1", diff --git a/xla/tsl/cuda/cublasLt_stub.cc b/xla/tsl/cuda/cublasLt_stub.cc index 
db60995d59fa5..cada3c2b0d56b 100644 --- a/xla/tsl/cuda/cublasLt_stub.cc +++ b/xla/tsl/cuda/cublasLt_stub.cc @@ -20,6 +20,9 @@ limitations under the License. // Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void* GetDsoHandle() { @@ -67,3 +70,4 @@ void _cublasLt_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cublas_stub.cc b/xla/tsl/cuda/cublas_stub.cc index a4b7fcbb828b6..12662c8fb4fb9 100644 --- a/xla/tsl/cuda/cublas_stub.cc +++ b/xla/tsl/cuda/cublas_stub.cc @@ -30,6 +30,9 @@ limitations under the License. // Implements the cuBLAS API by forwarding to cuBLAS loaded from the DSO. // Note that it does not implement the v1 interface. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void *GetDsoHandle() { @@ -244,3 +247,4 @@ void _cublas_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cuda_stub.cc b/xla/tsl/cuda/cuda_stub.cc index e33535c16e33c..b0cf4420da060 100644 --- a/xla/tsl/cuda/cuda_stub.cc +++ b/xla/tsl/cuda/cuda_stub.cc @@ -19,6 +19,9 @@ limitations under the License. // Implements the CUDA driver API by forwarding to CUDA loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void* GetDsoHandle() { @@ -70,3 +73,4 @@ void _cuda_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cudart_stub.cc b/xla/tsl/cuda/cudart_stub.cc index 7064a72541eef..f22c89b7086a3 100644 --- a/xla/tsl/cuda/cudart_stub.cc +++ b/xla/tsl/cuda/cudart_stub.cc @@ -24,6 +24,9 @@ limitations under the License. 
#include "tsl/platform/load_library.h" #include "tsl/platform/logging.h" +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { void *GetDsoHandle() { static auto handle = []() -> void * { @@ -89,3 +92,4 @@ void _cudart_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cudnn_stub.cc b/xla/tsl/cuda/cudnn_stub.cc index 192009c9e8728..efa3e58027dc1 100644 --- a/xla/tsl/cuda/cudnn_stub.cc +++ b/xla/tsl/cuda/cudnn_stub.cc @@ -21,6 +21,9 @@ limitations under the License. // Implements the cuDNN API by forwarding to cuDNN loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void* GetDsoHandle() { @@ -94,3 +97,5 @@ void _cudnn_tramp_resolve(int i) { } } // extern "C" + +#endif diff --git a/xla/tsl/cuda/cufft_stub.cc b/xla/tsl/cuda/cufft_stub.cc index ea7b08f882189..534ceac97d882 100644 --- a/xla/tsl/cuda/cufft_stub.cc +++ b/xla/tsl/cuda/cufft_stub.cc @@ -20,6 +20,9 @@ limitations under the License. // Implements the cuFFT API by forwarding to cuFFT loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void* GetDsoHandle() { @@ -69,3 +72,4 @@ void _cufft_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cupti_stub.cc b/xla/tsl/cuda/cupti_stub.cc index 01d13a8ea7d4f..2cc0e6c331945 100644 --- a/xla/tsl/cuda/cupti_stub.cc +++ b/xla/tsl/cuda/cupti_stub.cc @@ -21,6 +21,9 @@ limitations under the License. // Implements the CUPTI API by forwarding to CUPTI loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. 
void* GetDsoHandle() { @@ -70,3 +73,4 @@ void _cupti_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cusolver_stub.cc b/xla/tsl/cuda/cusolver_stub.cc index d76526042582e..d85fa1a8ee856 100644 --- a/xla/tsl/cuda/cusolver_stub.cc +++ b/xla/tsl/cuda/cusolver_stub.cc @@ -21,6 +21,9 @@ limitations under the License. // Implements the cusolver API by forwarding to cusolver loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void* GetDsoHandle() { @@ -72,3 +75,4 @@ void _cusolver_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/cusparse_stub.cc b/xla/tsl/cuda/cusparse_stub.cc index b8ab1d67354bd..2b1003b0e9e1b 100644 --- a/xla/tsl/cuda/cusparse_stub.cc +++ b/xla/tsl/cuda/cusparse_stub.cc @@ -24,6 +24,9 @@ limitations under the License. // Implements the cusparse API by forwarding to cusparse loaded from the DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. void* GetDsoHandle() { @@ -92,3 +95,4 @@ void _cusparse_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/nccl_stub.cc b/xla/tsl/cuda/nccl_stub.cc index f3895da245176..3696d8a2bf460 100644 --- a/xla/tsl/cuda/nccl_stub.cc +++ b/xla/tsl/cuda/nccl_stub.cc @@ -23,6 +23,9 @@ limitations under the License. // Implements the nccl API by forwarding to nccl loaded from a DSO. +// Note that we do not need this for MSVC because it already uses lazy loading. +#if !defined(_MSC_VER) + namespace { // Returns DSO handle or null if loading the DSO fails. 
void* GetDsoHandle() { @@ -91,3 +94,4 @@ void _nccl_tramp_resolve(int i) { } } // extern "C" +#endif diff --git a/xla/tsl/cuda/stub.bzl b/xla/tsl/cuda/stub.bzl index 2f98d8034510e..9b6c350d7c64f 100644 --- a/xla/tsl/cuda/stub.bzl +++ b/xla/tsl/cuda/stub.bzl @@ -22,6 +22,7 @@ def cuda_stub(name, srcs): "@xla//xla/tsl:linux_aarch64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target aarch64", "@xla//xla/tsl:linux_x86_64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target x86_64", "@xla//xla/tsl:linux_ppc64le": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target powerpc64le", + "@xla//xla/tsl:windows_x86_64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target x86_64", "//conditions:default": "NOT_IMPLEMENTED_FOR_THIS_PLATFORM_OR_ARCHITECTURE", }), )