Skip to content

Commit

Permalink
Remove CUDA dependencies from jaxlib OSS wheel.
Browse files Browse the repository at this point in the history
With this change, the `jaxlib` wheel content is identical for CPU and GPU configurations. This enables reusing the Bazel cache when building all three targets together with `--config=cuda`: `build_wheel`, `build_gpu_plugin_wheel`, and `build_gpu_kernels_wheel`.

PiperOrigin-RevId: 706016685
  • Loading branch information
Google-ML-Automation committed Dec 23, 2024
1 parent e19d979 commit 92e7e5a
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
11 changes: 7 additions & 4 deletions third_party/tsl/tsl/profiler/lib/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@xla//xla/tsl:tsl.bzl", "if_not_android", "if_oss", "internal_visibility", "nvtx_headers")
load("@xla//xla/tsl:tsl.bzl", "if_google", "if_not_android", "if_oss", "internal_visibility", "nvtx_headers")
load("@xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
load("@xla//xla/tsl/platform:build_config.bzl", "tsl_cc_test")
load("@xla//xla/tsl/platform:build_config_root.bzl", "if_static")
Expand Down Expand Up @@ -278,14 +278,17 @@ cc_library(

cc_library(
name = "nvtx_utils_impl",
srcs = if_cuda_is_configured(
["nvtx_utils.cc"],
srcs = if_google(
if_cuda_is_configured(
["nvtx_utils.cc"],
["nvtx_utils_stub.cc"],
),
["nvtx_utils_stub.cc"],
),
hdrs = ["nvtx_utils.h"],
local_defines = if_oss(["NVTX_VERSION_3_1=1"]),
visibility = ["//visibility:public"],
deps = if_cuda_is_configured(nvtx_headers()),
deps = if_google(if_cuda_is_configured(nvtx_headers())),
)

cc_library(
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/profiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package_group(

tsl_gpu_library(
name = "profiler_backends",
add_gpu_deps_for_oss = False,
# copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
visibility = internal_visibility(["//xla:internal"]),
deps = [
"//xla/backends/profiler/cpu:host_tracer",
Expand Down
20 changes: 4 additions & 16 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ cc_library(
"-fexceptions",
"-fno-strict-aliasing",
],
defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm([
defines = if_google(if_cuda(["GOOGLE_CUDA=1"]) + if_rocm([
"TENSORFLOW_USE_ROCM=1",
]),
])),
features = ["-use_header_modules"],
deps = [
":aggregate_profile",
Expand Down Expand Up @@ -438,9 +438,9 @@ cc_library(
# keep sorted
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
]) + if_cuda_or_rocm([
]) + if_google(if_cuda_or_rocm([
":py_client_gpu", # TODO(b/337876408): remove after migration to plugin
]) + if_google(["@com_google_protobuf//:any_cc_proto"]),
]) + ["@com_google_protobuf//:any_cc_proto"]),
)

cc_library(
Expand Down Expand Up @@ -1244,18 +1244,6 @@ tsl_pybind_extension(
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
linkopts = select({
":use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../nvidia/cuda_cupti/lib",
"-Wl,-rpath,$$ORIGIN/../nvidia/cuda_runtime/lib",
"-Wl,-rpath,$$ORIGIN/../nvidia/cublas/lib",
"-Wl,-rpath,$$ORIGIN/../nvidia/cufft/lib",
"-Wl,-rpath,$$ORIGIN/../nvidia/cudnn/lib",
"-Wl,-rpath,$$ORIGIN/../nvidia/cusolver/lib",
"-Wl,-rpath,$$ORIGIN/../nvidia/nccl/lib",
],
"//conditions:default": [],
}),
pytype_deps = [
"//third_party/py/numpy",
],
Expand Down
2 changes: 2 additions & 0 deletions xla/tsl/distributed_runtime/coordination/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ cc_library(
tsl_gpu_library(
name = "coordination_service_impl",
srcs = ["coordination_service.cc"],
add_gpu_deps_for_oss = False,
deps = [
":coordination_client",
":coordination_service",
Expand Down Expand Up @@ -141,6 +142,7 @@ tsl_gpu_library(
name = "coordination_service_agent",
srcs = ["coordination_service_agent.cc"],
hdrs = ["coordination_service_agent.h"],
add_gpu_deps_for_oss = False,
deps = [
":coordination_client",
":coordination_service_error_util",
Expand Down
31 changes: 21 additions & 10 deletions xla/tsl/tsl.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,12 @@ def tf_openmp_copts():
"//conditions:default": [],
})

def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs):
def tsl_gpu_library(
deps = None,
cuda_deps = None,
copts = tsl_copts(),
add_gpu_deps_for_oss = True,
**kwargs):
"""Generate a cc_library with a conditional set of CUDA dependencies.
When the library is built with --config=cuda:
Expand All @@ -373,6 +378,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs
'--config=cuda' is passed to the bazel command line.
deps: dependencies which will always be linked.
copts: copts always passed to the cc_library.
add_gpu_deps_for_oss: Whether to add gpu deps for OSS too.
**kwargs: Any other argument to cc_library.
"""
if not deps:
Expand All @@ -381,19 +387,24 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs
cuda_deps = []

kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
deps = deps + if_cuda(cuda_deps)
deps = deps + (if_cuda(cuda_deps) if add_gpu_deps_for_oss else if_google(if_cuda(cuda_deps)))
if "default_copts" in kwargs:
copts = kwargs["default_copts"] + copts
kwargs.pop("default_copts", None)
all_cuda_deps = if_cuda([
clean_dep("//xla/tsl/cuda:cudart"),
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm([
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
])
all_cuda_copts = if_cuda(["-DGOOGLE_CUDA=1", "-DNV_CUDNN_DISABLE_EXCEPTION"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"])
if not add_gpu_deps_for_oss:
all_cuda_deps = if_google(all_cuda_deps)
all_cuda_copts = if_google(all_cuda_copts)
cc_library(
deps = deps + if_cuda([
clean_dep("//xla/tsl/cuda:cudart"),
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm([
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
]),
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1", "-DNV_CUDNN_DISABLE_EXCEPTION"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
deps = deps + all_cuda_deps,
copts = (copts + all_cuda_copts + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs
)

Expand Down

0 comments on commit 92e7e5a

Please sign in to comment.