Address possible breaking changes from jax==0.4.36
#1434
dpanici added the `bug` ("Something isn't working") and `P3` ("Highest Priority, someone is/should be actively working on this") labels on Dec 6, 2024.
Most of these issues look like they might be fixed upstream in equinox: patrick-kidger/equinox#907
Looks like it's actually due to jax-ml/jax#25329.
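The 0.4.37 release notes describe jax-ml/jax#25329 as `jit` erroring "if an argument was named `f`", and the matching commit is titled "Use private names for args in api_util to avoid shadowing kwargs keys." Below is a minimal pure-Python analogy of that failure mode. It assumes the bug is an ordinary keyword-name collision inside a wrapper. It is not JAX's actual code, and `buggy_jit`, `fixed_jit`, and `scale` are hypothetical names:

```python
import functools


def buggy_jit(fun):
    """Hypothetical wrapper whose internal helper names its callable parameter `f`."""

    def call_wrapped(f, *args, **kwargs):
        return f(*args, **kwargs)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # A user keyword argument named `f` collides with call_wrapped's own `f`,
        # raising "TypeError: call_wrapped() got multiple values for argument 'f'".
        return call_wrapped(fun, *args, **kwargs)

    return wrapper


def fixed_jit(fun):
    """Same wrapper, but the helper's parameter is private and positional-only."""

    def call_wrapped(_fun, /, *args, **kwargs):
        # `_fun` can only be passed positionally, so a user keyword `f`
        # (or even `_fun`) lands in **kwargs instead of shadowing it.
        return _fun(*args, **kwargs)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        return call_wrapped(fun, *args, **kwargs)

    return wrapper


def scale(x, f):
    """Hypothetical user function that happens to name an argument `f`."""
    return x * f
```

In this sketch, `buggy_jit(scale)(3.0, 2.0)` works and only `buggy_jit(scale)(3.0, f=2.0)` fails. A collision like this depends on how the function is called, not just on which function is called.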
This was referenced on Dec 7, 2024 (merged).

f0uriest added a commit that referenced this issue on Dec 12, 2024:
Resolves #1434 Updates the requirements on [jax](https://github.com/jax-ml/jax) and [diffrax](https://github.com/patrick-kidger/diffrax) to permit the latest version. Updates `jax` to 0.4.37 <details> <summary>Release notes</summary> <p><em>Sourced from <a href="https://github.com/jax-ml/jax/releases">jax's releases</a>.</em></p> <blockquote> <h2>JAX v0.4.37</h2> <p>This is a patch release of jax 0.4.36. Only "jax" was released at this version.</p> <ul> <li>Bug fixes <ul> <li>Fixed a bug where <code>jit</code> would error if an argument was named <code>f</code> (<a href="https://redirect.github.com/jax-ml/jax/issues/25329">#25329</a>).</li> <li>Fix a bug that will throw <code>index out of range</code> error in <code>jax.lax.while_loop</code> if the user registers pytree node class with different aux data for the flatten and flatten_with_path.</li> <li>Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.</li> </ul> </li> </ul> </blockquote> </details> <details> <summary>Changelog</summary> <p><em>Sourced from <a href="https://github.com/jax-ml/jax/blob/main/CHANGELOG.md">jax's changelog</a>.</em></p> <blockquote> <h2>jax 0.4.37 (Dec 9, 2024)</h2> <p>This is a patch release of jax 0.4.36. Only "jax" was released at this version.</p> <ul> <li>Bug fixes <ul> <li>Fixed a bug where <code>jit</code> would error if an argument was named <code>f</code> (<a href="https://redirect.github.com/jax-ml/jax/issues/25329">#25329</a>).</li> <li>Fix a bug that will throw <code>index out of range</code> error in {func}<code>jax.lax.while_loop</code> if the user register pytree node class with different aux data for the flatten and flatten_with_path.</li> <li>Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.</li> </ul> </li> </ul> <h2>jax 0.4.36 (Dec 5, 2024)</h2> <ul> <li>Breaking Changes <ul> <li> <p>This release lands "stackless", an internal change to JAX's tracing machinery. 
We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, <code>post_process_call</code>, <code>new_base_main</code>, <code>custom_bind</code>, and so on. The change should only affect users that use JAX internals.</p> <p>If you do use JAX internals then you may need to update your code (see <a href="https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f">https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f</a> for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find this change breaks your non-JAX-internals-using code then try the <code>config.jax_data_dependent_tracing_fallback</code> flag as a workaround, and if you need help updating your code then please file a bug.</p> </li> <li> <p>{func}<code>jax.experimental.jax2tf.convert</code> with <code>native_serialization=False</code> or with <code>enable_xla=False</code> have been deprecated since July 2024, with JAX version 0.4.31. Now we removed support for these use cases. <code>jax2tf</code> with native serialization will still be supported.</p> </li> <li> <p>In <code>jax.interpreters.xla</code>, the <code>xb</code>, <code>xc</code>, and <code>xe</code> symbols have been removed after being deprecated in JAX v0.4.31. Instead use <code>xb = jax.lib.xla_bridge</code>, <code>xc = jax.lib.xla_client</code>, and <code>xe = jax.lib.xla_extension</code>.</p> </li> <li> <p>The deprecated module <code>jax.experimental.export</code> has been removed. It was replaced by {mod}<code>jax.export</code> in JAX v0.4.30. 
See the <a href="https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export">migration guide</a> for information on migrating to the new API.</p> </li> <li> <p>The <code>initial</code> argument to {func}<code>jax.nn.softmax</code> and {func}<code>jax.nn.log_softmax</code> has been removed, after being deprecated in v0.4.27.</p> </li> <li> <p>Calling <code>np.asarray</code> on typed PRNG keys (i.e. keys produced by :func:<code>jax.random.key</code>) now raises an error. Previously, this returned a scalar object array.</p> </li> <li> <p>The following deprecated methods and functions in {mod}<code>jax.export</code> have been removed:</p> <ul> <li><code>jax.export.DisabledSafetyCheck.shape_assertions</code>: it had no effect already.</li> <li><code>jax.export.Exported.lowering_platforms</code>: use <code>platforms</code>.</li> <li><code>jax.export.Exported.mlir_module_serialization_version</code>: use <code>calling_convention_version</code>.</li> </ul> </li> </ul> </li> </ul> <!-- raw HTML omitted --> </blockquote> <p>... 
(truncated)</p> </details> <details> <summary>Commits</summary> <ul> <li><a href="https://github.com/jax-ml/jax/commit/ffb07cdadb5dc3bc43485cf041dbc2b43136109e"><code>ffb07cd</code></a> Update versions for v0.4.37 release.</li> <li><a href="https://github.com/jax-ml/jax/commit/95892fdac86524151b6dadd7d8bedbf915f1500f"><code>95892fd</code></a> Use private names for args in api_util to avoid shadowing kwargs keys.</li> <li><a href="https://github.com/jax-ml/jax/commit/65b60884114261549ffc2eb937162bdeaa493928"><code>65b6088</code></a> Avoid index out of range error in carry structure check</li> <li><a href="https://github.com/jax-ml/jax/commit/259194a69f52a06847a9ff11eb268072e91fd65f"><code>259194a</code></a> [Pallas] Fix shard_axis in dma_start interpret mode rule.</li> <li><a href="https://github.com/jax-ml/jax/commit/7e6620a57775084dfa8d438ae4fd27f3ef365018"><code>7e6620a</code></a> JAX release 0.4.36.</li> <li><a href="https://github.com/jax-ml/jax/commit/23d5c10ff0704f66ad7ec65a8cdcd09bd2420591"><code>23d5c10</code></a> [Mosaic:TPU] Fix fully replicated relayout</li> <li><a href="https://github.com/jax-ml/jax/commit/2a4a0e8d6fb36b59f9c6f24e0018d42c8c8d8ee9"><code>2a4a0e8</code></a> [jax:custom_partitioning] Implement SdyShardingRule to support</li> <li><a href="https://github.com/jax-ml/jax/commit/f73fa7a7ad64b2f15e8669beed14600704287b93"><code>f73fa7a</code></a> Merge pull request <a href="https://redirect.github.com/jax-ml/jax/issues/25290">#25290</a> from jakevdp:reduction-where</li> <li><a href="https://github.com/jax-ml/jax/commit/a71f9a62e6f67640a4b0578d042b07792fcf407a"><code>a71f9a6</code></a> Merge pull request <a href="https://redirect.github.com/jax-ml/jax/issues/25271">#25271</a> from jakevdp:fix-vector-norm</li> <li><a href="https://github.com/jax-ml/jax/commit/e20a483befbb80bbf782b931ec57a44c78c313b8"><code>e20a483</code></a> [JAX] Add end-to-end execution support in colocated Python API</li> <li>Additional commits viewable in <a 
href="https://github.com/jax-ml/jax/compare/jax-v0.4.24...jax-v0.4.37">compare view</a></li> </ul> </details> <br /> Updates `diffrax` to 0.6.1 <details> <summary>Release notes</summary> <p><em>Sourced from <a href="https://github.com/patrick-kidger/diffrax/releases">diffrax's releases</a>.</em></p> <blockquote> <h2>Diffrax v0.6.1</h2> <h3>Features</h3> <ul> <li> <p>Compatibility with JAX 0.4.36.</p> </li> <li> <p>New solvers! Added stochastic Runge--Kutta methods for solving the underdamped Langevin equation. We now have:</p> <ul> <li><code>diffrax.AbstractFosterLangevinSRK</code></li> <li><code>diffrax.ALIGN</code></li> <li><code>diffrax.QUICSORT</code></li> <li><code>diffrax.ShOULD</code></li> </ul> <p>and these are used with the corresponding</p> <ul> <li><code>diffrax.UnderdampedLangevinDriftTerm</code></li> <li><code>diffrax.UnderdampedLangevinDiffusionTerm</code></li> </ul> <p>huge thanks to <a href="https://github.com/andyElking"><code>@andyElking</code></a> for carefully implementing all of these, which was a huge technical task. (<a href="https://redirect.github.com/patrick-kidger/diffrax/issues/453">#453</a> and 2000 new lines of code!) See <a href="https://docs.kidger.site/diffrax/examples/underdamped_langevin_example/">the Underdamped Langevin Diffusion example</a> for more on how to use these.</p> </li> </ul> <h3>Bugfixes</h3> <ul> <li>If <code>t0 == t1</code> and we have <code>SaveAt(ts=...)</code> then we now correctly output <code>len(ts)</code> copies of <code>y0</code>. (Thanks <a href="https://github.com/dkweiss31"><code>@dkweiss31</code></a>! <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/488">#488</a>, <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/494">#494</a>)</li> <li>When using <code>diffrax.VirtualBrownianTree</code> on the GPU then floating point fluctuations would sometimes produce evaluations outside of the valid <code>[t0, t1]</code> region, which would raise a spurious runtime error. 
This is now fixed. (Thanks <a href="https://github.com/mattlevine22"><code>@mattlevine22</code></a>! <a href="https://redirect.github.com/jax-ml/jax/issues/24807">jax-ml/jax#24807</a>, <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/524">#524</a>, <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/526">#526</a>)</li> <li>Complex fixes in SDEs (Thanks <a href="https://github.com/Randl"><code>@Randl</code></a>! <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/454">#454</a>)</li> <li>Improvements to errors, warnings, and some typo fixes (Thanks <a href="https://github.com/lockwo"><code>@lockwo</code></a> <a href="https://github.com/ddrous"><code>@ddrous</code></a>! <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/468">#468</a>, <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/478">#478</a>, <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/495">#495</a>, <a href="https://redirect.github.com/patrick-kidger/diffrax/issues/530">#530</a>)</li> </ul> <h2>New Contributors</h2> <ul> <li><a href="https://github.com/ddrous"><code>@ddrous</code></a> made their first contribution in <a href="https://redirect.github.com/patrick-kidger/diffrax/pull/530">patrick-kidger/diffrax#530</a></li> </ul> <p><strong>Full Changelog</strong>: <a href="https://github.com/patrick-kidger/diffrax/compare/v0.6.0...v0.6.1">https://github.com/patrick-kidger/diffrax/compare/v0.6.0...v0.6.1</a></p> </blockquote> </details> <details> <summary>Commits</summary> <ul> <li><a href="https://github.com/patrick-kidger/diffrax/commit/78531fa2ae15f0e8ce7356148642122a16d3531a"><code>78531fa</code></a> Bump minimum Equinox version to one that is compatible with latest JAX</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/825e4e06a84d15abc523c3d61c9af94fdc48de67"><code>825e4e0</code></a> Fixed where a nonbatchable check was being called.</li> <li><a 
href="https://github.com/patrick-kidger/diffrax/commit/1ae1d58cbbf2535f4e6bd27b5b52e7189c96b86f"><code>1ae1d58</code></a> version bump</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/72deb78fbb30af068c46a7ad3be4c22ed05f2fd9"><code>72deb78</code></a> Updated pre-commit to handle jaxtyping update</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/d3490e6c83402da8c58c70d4158d7f61e9f3e4a6"><code>d3490e6</code></a> Fixes for JAX 0.4.36 which changes the name of an error.</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/3c21d15405b20c07da407cfe34a688a3b4d5ff24"><code>3c21d15</code></a> Updates to the t0==t1 case to handle <code>SubSaveAt(fn=...)</code> and nonstandard dtyp...</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/ebd798064248b4e7102047a032a3e0d89acf5c99"><code>ebd7980</code></a> Save fix for <code>t0==t1</code> (<a href="https://redirect.github.com/patrick-kidger/diffrax/issues/494">#494</a>)</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/965f6b403b7c799067be4f68fbf972e23a817aed"><code>965f6b4</code></a> Compatibility with JAX 0.4.36, which removes ConcreteArray</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/beadc78dc7a58b58aceec18fa3bd86defb2f8c1d"><code>beadc78</code></a> bump doc building pipeline</li> <li><a href="https://github.com/patrick-kidger/diffrax/commit/0cf67d1268283354b3522deede7e7730b0b250e1"><code>0cf67d1</code></a> small fix of docs in all three and a return type in quicsort</li> <li>Additional commits viewable in <a href="https://github.com/patrick-kidger/diffrax/compare/v0.4.1...v0.6.1">compare view</a></li> </ul> </details> <br /> Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. 
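The diffrax 0.6.1 notes above mention that floating point fluctuations on the GPU could push `diffrax.VirtualBrownianTree` evaluations just outside `[t0, t1]` and trigger a spurious runtime error. One common guard against this kind of round-off is shown below as a pure-Python sketch. It is not necessarily diffrax's actual fix: it snaps times that overshoot by a tiny relative tolerance back into the interval, while still rejecting genuinely out-of-range queries. `clamp_time` and its `tol` parameter are hypothetical:

```python
def clamp_time(t: float, t0: float, t1: float, tol: float = 1e-6) -> float:
    """Snap t into [t0, t1] if it overshoots by at most tol * (t1 - t0).

    Queries further outside the interval are treated as real errors.
    """
    if t0 > t1:
        raise ValueError("expected t0 <= t1")
    # Allowed overshoot scales with the interval length (zero when t0 == t1).
    slack = tol * (t1 - t0)
    if t < t0 - slack or t > t1 + slack:
        raise ValueError(f"t={t} is outside [{t0}, {t1}]")
    # Within tolerance: clip so downstream code only ever sees t0 <= t <= t1.
    return min(max(t, t0), t1)
```

The tolerance is relative to the interval length, so the same setting behaves sensibly for short and long integration windows.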
JAX 0.4.36 was released just yesterday, and its release coincides with our CI failing mysteriously. We need a PR that either addresses these changes or limits the JAX upper version (ideally the former).
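Limiting the upper version would normally be done with a requirements specifier such as `jax < 0.4.36`, or `jax != 0.4.36` to exclude only the affected release. Below is a minimal pure-Python sketch of the same comparison, for example as a hypothetical runtime guard in a test suite. `parse_release` and `below_upper_bound` are hypothetical names. The parser deliberately ignores pre-release, dev, and local suffixes, which full PEP 440 tooling such as the `packaging` library handles properly:

```python
from __future__ import annotations


def parse_release(version: str) -> tuple[int, ...]:
    """Return the numeric release segment of a version string.

    Simplified sketch: "0.4.36" -> (0, 4, 36). Local ("+cuda12") and
    pre-release/dev suffixes ("rc1", ".dev0") are dropped, unlike PEP 440.
    """
    release = []
    for part in version.split("+", 1)[0].split("."):
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # e.g. the "dev0" in "0.4.36.dev0"
        release.append(int(digits))
        if len(digits) != len(part):
            break  # e.g. "36rc1": keep 36, drop the suffix
    return tuple(release)


def below_upper_bound(version: str, upper: str = "0.4.36") -> bool:
    """True if `version` is strictly below `upper`, comparing numeric segments."""
    v, u = parse_release(version), parse_release(upper)
    width = max(len(v), len(u))
    # Pad with zeros so that "0.4" compares like "0.4.0".
    v += (0,) * (width - len(v))
    u += (0,) * (width - len(u))
    return v < u
```

Segments are compared as integers, so `0.4.9` correctly sorts below `0.4.36`. A later patch release may fix the regression, as 0.4.37 did for jax-ml/jax#25329 according to its release notes. In that case, excluding the single bad version is less restrictive than a permanent cap.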
The failures seem related to `interp1d` calls from interpax. However, a simple call to `interp1d` from interpax with the newest JAX does not raise the error, so it must be something specific to how we are using it.