diff --git a/WORKSPACE b/WORKSPACE index 57e590ac5e0b..fef3937b8991 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # and update the sha256 with the result. http_archive( name = "xla", - sha256 = "617a968b2d4154ef4368e2676c72e2bc9a019be3b6a1941c8dc741d3e5ea3d8e", - strip_prefix = "xla-7a7cee6e31a01d0103c41b753c7e7fe6e0eeece8", + sha256 = "f8f6efd93237c94ee3c8be65dd0702a065c3e3e7b0f732181e62d33813f1bd35", + strip_prefix = "xla-34521d16cd91423a041be86bcb35319dfee3bedb", urls = [ - "https://github.com/openxla/xla/archive/7a7cee6e31a01d0103c41b753c7e7fe6e0eeece8.tar.gz", + "https://github.com/openxla/xla/archive/34521d16cd91423a041be86bcb35319dfee3bedb.tar.gz", ], ) diff --git a/setup.py b/setup.py index 2dfac16f4f79..8fe4e36c4694 100644 --- a/setup.py +++ b/setup.py @@ -19,13 +19,13 @@ from setuptools import setup, find_packages -_current_jaxlib_version = '0.4.11' +_current_jaxlib_version = '0.4.12' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.11' _available_cuda11_cudnn_versions = ['82', '86'] _default_cuda11_cudnn_version = '86' _default_cuda12_cudnn_version = '88' -_libtpu_version = '0.1.dev20230531' +_libtpu_version = '0.1.dev20230608' _dct = {} with open('jax/version.py', encoding='utf-8') as f: