Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jaxlib v0.4.38 #295

Merged
merged 7 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion recipe/build.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/bin/bash

set -euxo pipefail

# see comment in meta.yaml
cp $RECIPE_DIR/patches/0007-xla-cpu-Fix-build-errors-from-ACL.patch $SRC_DIR/third_party/xla/

export JAX_RELEASE=$PKG_VERSION

$RECIPE_DIR/add_py_toolchain.sh
Expand Down
10 changes: 8 additions & 2 deletions recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% set version = "0.4.36" %}
{% set version = "0.4.38" %}
{% set build = 0 %}

{% if cuda_compiler_version != "None" %}
Expand All @@ -13,10 +13,16 @@ package:
source:
# only pull sources after upstream PyPI release...
url: https://github.com/google/jax/archive/jax-v{{ version }}.tar.gz
sha256: 442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77
sha256: ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c
patches:
- patches/0001-Allow-for-custom-CUDA-build.patch
- patches/0002-Consolidated-build-fixes-for-XLA.patch
# cannot absorb this into the overall xla patch, because patching three vendored projects
# deep breaks the application of the patch using `patch` (presumably due to lines starting
# legitimately with `+++` being misinterpreted as a hunk separator).
# - patches/0007-xla-cpu-Fix-build-errors-from-ACL.patch
- patches/0003-fix-member-access-to-packed-CUDA-struct.patch
- patches/0004-fix-an-ambiguous-type.patch

build:
number: {{ build }}
Expand Down
30 changes: 8 additions & 22 deletions recipe/patches/0001-Allow-for-custom-CUDA-build.patch
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
From 803a48af47351366a13c1066c2049d3d8ec03767 Mon Sep 17 00:00:00 2001
From 47a24be649ae10c722a43f12dc7032b63d3c77b1 Mon Sep 17 00:00:00 2001
From: "Uwe L. Korn" <[email protected]>
Date: Sun, 8 Oct 2023 19:34:34 +0200
Subject: [PATCH 1/2] Allow for custom CUDA build
Subject: [PATCH 1/4] Allow for custom CUDA build

---
build/build.py | 14 ++++++++++----
build/tools/utils.py | 2 +-
2 files changed, 11 insertions(+), 5 deletions(-)
build/build.py | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/build/build.py b/build/build.py
index 25a873d..a83aec6 100755
index 5f43c2421..6939cc374 100755
--- a/build/build.py
+++ b/build/build.py
@@ -515,6 +515,13 @@ async def main():
@@ -527,6 +527,13 @@ async def main():

if args.cuda_version:
logging.debug("Hermetic CUDA version: %s", args.cuda_version)
Expand All @@ -26,11 +25,11 @@ index 25a873d..a83aec6 100755
wheel_build_command.append(
f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
)
@@ -597,10 +604,9 @@ async def main():
@@ -609,10 +616,9 @@ async def main():

wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

- result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
- result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
- # Exit with error if any wheel build fails.
- if result.return_code != 0:
- raise RuntimeError(f"Command failed with return code {result.return_code}")
Expand All @@ -40,16 +39,3 @@ index 25a873d..a83aec6 100755

# Exit with success if all wheels in the list were built successfully.
sys.exit(0)
diff --git a/build/tools/utils.py b/build/tools/utils.py
index 5d7c8e0..1a4b26f 100644
--- a/build/tools/utils.py
+++ b/build/tools/utils.py
@@ -222,7 +222,7 @@ def get_githash():
capture_output=True,
check=True,
).stdout.strip()
- except OSError:
+ except (subprocess.CalledProcessError, OSError):
return ""

def _parse_string_as_bool(s):
62 changes: 32 additions & 30 deletions recipe/patches/0002-Consolidated-build-fixes-for-XLA.patch
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
From 6a5e829b219f840595084bc5ce05bf8fb96e80c5 Mon Sep 17 00:00:00 2001
From a56d0a224f537e13570ddf60bb83300c30ce417b Mon Sep 17 00:00:00 2001
From: "Uwe L. Korn" <[email protected]>
Date: Thu, 14 Dec 2023 17:06:15 +0100
Subject: [PATCH 2/2] Consolidated build fixes for XLA
Subject: [PATCH 2/4] Consolidated build fixes for XLA

jax vendors xla, but only populates the sources through bazel, so we cannot
patch as usual through conda, but rather need to teach the bazel build file
for xla to apply those patches.

To maintain/rebase these patches, use a checkout of https://github.com/openxla/xla,
and then rebase to the commit of xla being used by jax, which can be found in
https://github.com/google/jax/blob/jaxlib-v{{ version }}/third_party/xla/workspace.bzl
https://github.com/google/jax/blob/jax-v{{ version }}/third_party/xla/workspace.bzl
which is also where we're patching in the list of patches to apply to xla.

Co-Authored-By: H. Vetinari <[email protected]>
Expand All @@ -20,8 +20,8 @@ Co-Authored-By: H. Vetinari <[email protected]>
...0004-Add-missing-bits-absl-systemlib.patch | 245 ++++++++++++++++++
...ther-absl-log-is-already-initialized.patch | 58 +++++
.../xla/0006-Add-conda-cuda-path.patch | 31 +++
third_party/xla/workspace.bzl | 8 +
7 files changed, 491 insertions(+)
third_party/xla/workspace.bzl | 10 +
7 files changed, 493 insertions(+)
create mode 100644 third_party/xla/0001-Support-third-party-build-of-boringssl.patch
create mode 100644 third_party/xla/0002-Fix-abseil-headers.patch
create mode 100644 third_party/xla/0003-Omit-usage-of-StrFormat.patch
Expand All @@ -31,14 +31,14 @@ Co-Authored-By: H. Vetinari <[email protected]>

diff --git a/third_party/xla/0001-Support-third-party-build-of-boringssl.patch b/third_party/xla/0001-Support-third-party-build-of-boringssl.patch
new file mode 100644
index 0000000..f1ca8bf
index 000000000..6337bd8ee
--- /dev/null
+++ b/third_party/xla/0001-Support-third-party-build-of-boringssl.patch
@@ -0,0 +1,51 @@
+From 0091b1688bec824633dad1d7bb95380be76fd9fd Mon Sep 17 00:00:00 2001
+From 49a684836e4cde37d326aaac66fd7ed74e396f3b Mon Sep 17 00:00:00 2001
+From: "Uwe L. Korn" <[email protected]>
+Date: Thu, 14 Dec 2023 15:04:51 +0100
+Subject: [PATCH 1/6] Support third-party build of boringssl
+Subject: [PATCH 1/7] Support third-party build of boringssl
+
+---
+ third_party/boringssl.BUILD | 21 +++++++++++++++++++++
Expand Down Expand Up @@ -74,10 +74,10 @@ index 0000000..f1ca8bf
++ ],
++)
+diff --git a/workspace2.bzl b/workspace2.bzl
+index 84ba78119b..70fefac97a 100644
+index 20ffe0cf82..a765265bad 100644
+--- a/workspace2.bzl
++++ b/workspace2.bzl
+@@ -69,7 +69,7 @@ def _tf_repositories():
+@@ -108,7 +108,7 @@ def _tf_repositories():
+ name = "boringssl",
+ sha256 = "9dc53f851107eaf87b391136d13b815df97ec8f76dadb487b58b2fc45e624d2c",
+ strip_prefix = "boringssl-c00d7ca810e93780bd0c8ee4eea28f4f2ea4bcdc",
Expand All @@ -88,14 +88,14 @@ index 0000000..f1ca8bf
+
diff --git a/third_party/xla/0002-Fix-abseil-headers.patch b/third_party/xla/0002-Fix-abseil-headers.patch
new file mode 100644
index 0000000..118b60e
index 000000000..dc4d73303
--- /dev/null
+++ b/third_party/xla/0002-Fix-abseil-headers.patch
@@ -0,0 +1,73 @@
+From 583296728feafa21677b4cf8d4e9eb9231b5d9fc Mon Sep 17 00:00:00 2001
+From bd9bdaf6ef5d57905620b167de2bba35d54ea930 Mon Sep 17 00:00:00 2001
+From: "Uwe L. Korn" <[email protected]>
+Date: Thu, 23 May 2024 15:45:52 +0200
+Subject: [PATCH 2/6] Fix abseil headers
+Subject: [PATCH 2/7] Fix abseil headers
+
+---
+ xla/python/ifrt_proxy/common/BUILD | 3 +++
Expand All @@ -105,7 +105,7 @@ index 0000000..118b60e
+ 4 files changed, 10 insertions(+)
+
+diff --git a/xla/python/ifrt_proxy/common/BUILD b/xla/python/ifrt_proxy/common/BUILD
+index b595a162ef..1a5f5d9a5f 100644
+index 506bcc0c47..11acfae12a 100644
+--- a/xla/python/ifrt_proxy/common/BUILD
++++ b/xla/python/ifrt_proxy/common/BUILD
+@@ -51,6 +51,9 @@ cc_library(
Expand All @@ -131,10 +131,10 @@ index 0000000..118b60e
+ "@tsl//tsl/platform:protobuf",
+ "@tsl//tsl/platform:status",
+diff --git a/xla/tsl/platform/default/BUILD b/xla/tsl/platform/default/BUILD
+index 86531a9c60..e427e1f1f0 100644
+index 0875f9a399..82379ca8df 100644
+--- a/xla/tsl/platform/default/BUILD
++++ b/xla/tsl/platform/default/BUILD
+@@ -229,6 +229,8 @@ cc_library(
+@@ -230,6 +230,8 @@ cc_library(
+ ],
+ deps = [
+ "@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -167,14 +167,14 @@ index 0000000..118b60e
+ "@com_google_absl//absl/time",
diff --git a/third_party/xla/0003-Omit-usage-of-StrFormat.patch b/third_party/xla/0003-Omit-usage-of-StrFormat.patch
new file mode 100644
index 0000000..048a570
index 000000000..293211c8e
--- /dev/null
+++ b/third_party/xla/0003-Omit-usage-of-StrFormat.patch
@@ -0,0 +1,25 @@
+From 7b7c1a3120412b65080050e17effd8de753eba2d Mon Sep 17 00:00:00 2001
+From e4c2b38053744f0ce7a7cd734f0415504a514bc1 Mon Sep 17 00:00:00 2001
+From: "Uwe L. Korn" <[email protected]>
+Date: Thu, 4 Jul 2024 10:36:03 +0200
+Subject: [PATCH 3/6] Omit usage of StrFormat
+Subject: [PATCH 3/7] Omit usage of StrFormat
+
+---
+ xla/stream_executor/gpu/gpu_executor.h | 5 +++++
Expand All @@ -198,14 +198,14 @@ index 0000000..048a570
+ #include <vector>
diff --git a/third_party/xla/0004-Add-missing-bits-absl-systemlib.patch b/third_party/xla/0004-Add-missing-bits-absl-systemlib.patch
new file mode 100644
index 0000000..a5fdf63
index 000000000..f63bf15d5
--- /dev/null
+++ b/third_party/xla/0004-Add-missing-bits-absl-systemlib.patch
@@ -0,0 +1,245 @@
+From db903958859521a6e2898690e893b1c11097139c Mon Sep 17 00:00:00 2001
+From 0ab5601817af80651b62decc237f01d461dc768b Mon Sep 17 00:00:00 2001
+From: "Uwe L. Korn" <[email protected]>
+Date: Thu, 4 Jul 2024 15:58:32 +0200
+Subject: [PATCH 4/6] Add missing bits absl systemlib
+Subject: [PATCH 4/7] Add missing bits absl systemlib
+
+---
+ .../third_party/absl/system.absl.base.BUILD | 16 +++++
Expand Down Expand Up @@ -449,14 +449,14 @@ index 0000000..a5fdf63
+ "numeric",
diff --git a/third_party/xla/0005-Check-whether-absl-log-is-already-initialized.patch b/third_party/xla/0005-Check-whether-absl-log-is-already-initialized.patch
new file mode 100644
index 0000000..76879ff
index 000000000..94007c0a6
--- /dev/null
+++ b/third_party/xla/0005-Check-whether-absl-log-is-already-initialized.patch
@@ -0,0 +1,58 @@
+From f576357c00ca90f74037b56167020f90a65a945c Mon Sep 17 00:00:00 2001
+From e95ed3a620df2ba0513ee44b50eacef391b50088 Mon Sep 17 00:00:00 2001
+From: "Uwe L. Korn" <[email protected]>
+Date: Fri, 22 Nov 2024 10:51:18 +0100
+Subject: [PATCH 5/6] Check whether absl log is already initialized
+Subject: [PATCH 5/7] Check whether absl log is already initialized
+
+---
+ xla/pjrt/c/pjrt_c_api_gpu.cc | 7 ++++++-
Expand Down Expand Up @@ -513,14 +513,14 @@ index 0000000..76879ff
+ } // namespace xla
diff --git a/third_party/xla/0006-Add-conda-cuda-path.patch b/third_party/xla/0006-Add-conda-cuda-path.patch
new file mode 100644
index 0000000..e3381f6
index 000000000..93999484f
--- /dev/null
+++ b/third_party/xla/0006-Add-conda-cuda-path.patch
@@ -0,0 +1,31 @@
+From 5ef91efdec33a1e7cdf963c545b0dd1f232096bd Mon Sep 17 00:00:00 2001
+From 25667f2b9448fff8cc462dd386c85c96740c64af Mon Sep 17 00:00:00 2001
+From: Silvio Traversaro <[email protected]>
+Date: Tue, 10 Dec 2024 09:42:51 +0100
+Subject: [PATCH 6/6] Add conda cuda path
+Subject: [PATCH 6/7] Add conda cuda path
+
+---
+ xla/tsl/platform/default/cuda_root_path.cc | 11 +++++++++++
Expand Down Expand Up @@ -549,10 +549,10 @@ index 0000000..e3381f6
+ #endif // defined(PLATFORM_POSIX) && !defined(__APPLE__)
+
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index db34354..4fd524a 100644
index d08d9bba9..4db026b92 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -30,6 +30,14 @@ def repo():
@@ -30,6 +30,16 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
Expand All @@ -563,6 +563,8 @@ index db34354..4fd524a 100644
+ "//third_party/xla:0004-Add-missing-bits-absl-systemlib.patch",
+ "//third_party/xla:0005-Check-whether-absl-log-is-already-initialized.patch",
+ "//third_party/xla:0006-Add-conda-cuda-path.patch",
+ # backport https://github.com/openxla/xla/commit/14ba309d98db689ba5c185287483d555e3307b7f
+ "//third_party/xla:0007-xla-cpu-Fix-build-errors-from-ACL.patch",
+ ],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
From be74d15a31142e7c19da9e787b148748cd4bae1d Mon Sep 17 00:00:00 2001
From: "H. Vetinari" <[email protected]>
Date: Sat, 22 Feb 2025 16:39:41 +1100
Subject: [PATCH 3/4] fix member access to packed CUDA struct

---
jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
index 8889a4983..f56b62fd6 100644
--- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
+++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -181,8 +181,9 @@ void callback_complete(CUcontext context, uint32_t streamId,
// Convert integer nanoseconds to floating point milliseconds to match
// the interface of the events-based profiler.
double duration_ms = (kernel->end - kernel->start) / 1e6;
+ const char* kernel_name = kernel->name;
profiler_state.timings.push_back(
- std::make_tuple(kernel->name, duration_ms));
+ std::make_tuple(kernel_name, duration_ms));
}
} else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
// no more records available
22 changes: 22 additions & 0 deletions recipe/patches/0004-fix-an-ambiguous-type.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
From 6e9cd86949724f046df92fc8b7348939f2fdc0be Mon Sep 17 00:00:00 2001
From: "H. Vetinari" <[email protected]>
Date: Sun, 23 Feb 2025 02:27:18 +1100
Subject: [PATCH 4/4] fix an ambiguous type

---
jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
index b21f56327..5efc1a39e 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -412,7 +412,7 @@ llvm::LogicalResult WGMMAOp::verify() {

int groups_m = 0;
auto a_shape = a_shaped_type.getShape();
- if (auto a_memref = dyn_cast<mlir::MemRefType>(getA().getType())) {
+ if (auto a_memref = mlir::dyn_cast<mlir::MemRefType>(getA().getType())) {
if (a_shape.size() != 4) {
return error("When `a` is a memref, it must have rank 4.");
}
Loading
Loading