Skip to content

Commit

Permalink
Links most cuda libs to jnitorch_cuda only.
Browse files Browse the repository at this point in the history
  • Loading branch information
HGuillemet committed Aug 11, 2024
1 parent 81e39b2 commit 1c621b3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
26 changes: 6 additions & 20 deletions pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,36 +101,20 @@
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.3/extras/CUPTI/lib64/",
"C:/Program Files/NVIDIA Corporation/NvToolsExt/bin/x64/",
},
linkpath = {
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.3/lib/x64/",
"/usr/local/cuda-12.3/lib64/",
"/usr/local/cuda/lib64/",
"/usr/lib64/"
},
extension = "-gpu"
),
@Platform(
value = {"linux"},
link = { "c10", "torch", "torch_cpu" }
link = { "c10", "torch_cpu" }
),
@Platform(
value = {"macosx"},
link = { "c10", "torch", "torch_cpu", "omp" }
link = { "c10", "torch_cpu", "omp" }
),
@Platform(
value = "windows",
link = { "c10", "torch", "torch_cpu", "uv" }
link = { "c10", "torch_cpu", "uv" }
),
@Platform(
value = "linux",
extension = "-gpu",
link = { "c10", "torch", "torch_cpu", "c10_cuda", "torch_cuda", "torch_cuda_linalg", "cudart", "cusparse", "cudnn" } // [email protected] needed ? cuda_linalg built as separate lib on linux only
),
@Platform(
value = "windows",
extension = "-gpu",
link = { "c10", "torch", "torch_cpu", "uv", "c10_cuda", "torch_cuda", "cudart", "cusparse", "cudnn" }
)
},
target = "org.bytedeco.pytorch",
global = "org.bytedeco.pytorch.global.torch"
Expand Down Expand Up @@ -180,11 +164,13 @@ public void init(ClassProperties properties) {
if (!Loader.isLoadLibraries() || extension == null || !extension.endsWith("-gpu")) {
return;
}

// when built for CUDA, even torch_cpu links with at least cupti and cudart, for some reason
int i = 0;
if (platform.startsWith("windows")) {
preloads.add(i++, "zlibwapi");
}
String[] libs = {"cudart", "cublasLt", "cublas", "cufft", "curand", "nvJitLink", "cusparse", "cusolver",
String[] libs = {"cudart", "cublasLt", "cublas", "cufft", "cupti", "curand", "nvJitLink", "cusparse", "cusolver",
"cudnn", "nccl", "nvrtc", "nvrtc-builtins", "myelin", "nvinfer", "cudnn_ops_infer", "cudnn_ops_train",
"cudnn_adv_infer", "cudnn_adv_train", "cudnn_cnn_infer", "cudnn_cnn_train"};
for (String lib : libs) {
Expand Down
18 changes: 18 additions & 0 deletions pytorch/src/main/java/org/bytedeco/pytorch/presets/torch_cuda.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@
"<driver_functions.h>" // causes #warning
}
),
@Platform(
value = "linux",
extension = "-gpu",
link = { "c10", "torch_cpu", "c10_cuda", "torch_cuda", "torch_cuda_linalg", "cudart", "cupti", "cusparse", "cudnn" }, // cuda_linalg built as separate lib on linux only
linkpath = {
"/usr/local/cuda-12.3/lib64/",
"/usr/local/cuda/lib64/",
"/usr/lib64/"
}
),
@Platform(
value = "windows",
extension = "-gpu",
link = { "c10", "torch_cpu", "uv", "c10_cuda", "torch_cuda", "cudart", "cupti", "cusparse", "cudnn" },
linkpath = {
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.3/lib/x64/"
}
)
},
target = "org.bytedeco.pytorch.cuda",
global = "org.bytedeco.pytorch.global.torch_cuda"
Expand Down

0 comments on commit 1c621b3

Please sign in to comment.