Address possible breaking changes from jax==0.4.36 #1434

Closed
dpanici opened this issue Dec 6, 2024 · 2 comments · Fixed by #1461
Labels
bug Something isn't working P3 Highest Priority, someone is/should be actively working on this

Comments

@dpanici
Collaborator

dpanici commented Dec 6, 2024

JAX 0.4.36 was just released yesterday and coincides with our CI failing mysteriously, so we need a PR to address these changes or limit the JAX upper version (ideally the former).

This seems related to `interp1d` calls from interpax, though a simple call to `interp1d` from interpax with the newest JAX does not reproduce the error, so it must be something specific to how we are using it.
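One alternative to capping the upper version is to exclude only the offending release, e.g. with the PEP 440 specifier `jax != 0.4.36`. The stdlib-only sketch below mimics that check for a plain version string; the helper names and the exclusion set are hypothetical, and a real project would normally rely on `packaging.specifiers.SpecifierSet` instead.

```python
import re

# Hypothetical: releases to keep out of the dependency range.
EXCLUDED_RELEASES = {(0, 4, 36)}


def release_tuple(version: str) -> tuple[int, ...]:
    """Return the leading numeric release segment, e.g. "0.4.36.dev1" -> (0, 4, 36)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"not a release-style version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jax_version_allowed(version: str) -> bool:
    """Reject only the known-bad release instead of every newer one.

    Note: pre-releases such as "0.4.36rc1" map to (0, 4, 36) and are rejected too.
    """
    return release_tuple(version) not in EXCLUDED_RELEASES
```

With this check, `"0.4.35"` and `"0.4.37"` pass while `"0.4.36"` is rejected, so a later patch release is picked up automatically once it ships.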

@dpanici dpanici added bug Something isn't working P3 Highest Priority, someone is/should be actively working on this labels Dec 6, 2024
@f0uriest
Member

f0uriest commented Dec 7, 2024

Most of these issues look like they might be fixed upstream in equinox: patrick-kidger/equinox#907

@f0uriest
Member

f0uriest commented Dec 7, 2024

Looks like it's actually due to jax-ml/jax#25329
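The upstream fix (commit 95892fd, "Use private names for args in api_util to avoid shadowing kwargs keys") points at the failure mode: a wrapper whose own parameter name collides with a user keyword. The plain-Python sketch below reproduces that class of bug. It is an illustration only, not JAX's actual `api_util` code; `call_buggy`, `call_fixed`, and `scale` are hypothetical names, and the fix shown uses a positional-only parameter rather than JAX's private-name approach.

```python
def call_buggy(f, *args, **kwargs):
    # The wrapper's own parameter is named `f`, so a user keyword `f=...`
    # collides with it before the wrapped function is ever called.
    return f(*args, **kwargs)


def call_fixed(fun, /, *args, **kwargs):
    # Positional-only parameter (PEP 570): a user keyword `f=...`, or even
    # `fun=...`, now lands in **kwargs instead of clashing with the wrapper.
    return fun(*args, **kwargs)


def scale(x, f):
    # A user function with an argument that happens to be named `f`.
    return x * f
```

`call_buggy(scale, 3, f=2)` raises `TypeError` ("got multiple values for argument 'f'"), while `call_fixed(scale, 3, f=2)` forwards the keyword and returns `6`.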

f0uriest added a commit that referenced this issue Dec 12, 2024
Resolves #1434 

Updates the requirements on [jax](https://github.com/jax-ml/jax) and
[diffrax](https://github.com/patrick-kidger/diffrax) to permit the
latest version.
Updates `jax` to 0.4.37
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/jax-ml/jax/releases">jax's
releases</a>.</em></p>
<blockquote>
<h2>JAX v0.4.37</h2>
<p>This is a patch release of jax 0.4.36. Only &quot;jax&quot; was
released at this version.</p>
<ul>
<li>Bug fixes
<ul>
<li>Fixed a bug where <code>jit</code> would error if an argument was
named <code>f</code> (<a
href="https://redirect.github.com/jax-ml/jax/issues/25329">#25329</a>).</li>
<li>Fix a bug that will throw <code>index out of range</code> error in
<code>jax.lax.while_loop</code> if the user registers pytree node class
with
different aux data for the flatten and flatten_with_path.</li>
<li>Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU
v6e.</li>
</ul>
</li>
</ul>
</blockquote>
</details>
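The `while_loop` fix above concerns pytree node classes whose flatten and flatten_with_path functions disagree on aux data. A minimal sketch, assuming a hypothetical `Carry` class, of registering one consistently with `jax.tree_util.register_pytree_with_keys` so both functions return the same children order and the same aux data, then using it as a `jax.lax.while_loop` carry:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey, register_pytree_with_keys


class Carry:
    """Hypothetical loop state: two array leaves plus a static label."""

    def __init__(self, step, value, label):
        self.step = step
        self.value = value
        self.label = label  # static metadata, kept in the aux data


def flatten_with_keys(c):
    # Children are (key, leaf) pairs; aux data is the static label.
    children = ((GetAttrKey("step"), c.step), (GetAttrKey("value"), c.value))
    return children, c.label


def flatten(c):
    # Same children order and the same aux data as flatten_with_keys.
    return (c.step, c.value), c.label


def unflatten(label, children):
    step, value = children
    return Carry(step, value, label)


register_pytree_with_keys(Carry, flatten_with_keys, unflatten, flatten)


def double_n_times(n):
    init = Carry(jnp.int32(0), jnp.float32(1.0), "doubling")

    def cond(c):
        return c.step < n

    def body(c):
        # Rebuild with the same label so the carry structure matches init.
        return Carry(c.step + 1, c.value * 2.0, c.label)

    return jax.lax.while_loop(cond, body, init)


out = double_n_times(3)
```

Keeping the aux data identical in both flatten paths means the carry structure check compares like with like, whatever JAX version is installed.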
<details>
<summary>Changelog</summary>
<p><em>Sourced from <a
href="https://github.com/jax-ml/jax/blob/main/CHANGELOG.md">jax's
changelog</a>.</em></p>
<blockquote>
<h2>jax 0.4.37 (Dec 9, 2024)</h2>
<p>This is a patch release of jax 0.4.36. Only &quot;jax&quot; was
released at this version.</p>
<ul>
<li>Bug fixes
<ul>
<li>Fixed a bug where <code>jit</code> would error if an argument was
named <code>f</code> (<a
href="https://redirect.github.com/jax-ml/jax/issues/25329">#25329</a>).</li>
<li>Fix a bug that will throw <code>index out of range</code> error in
<code>jax.lax.while_loop</code> if the user registers pytree node
class with
different aux data for the flatten and flatten_with_path.</li>
<li>Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU
v6e.</li>
</ul>
</li>
</ul>
<h2>jax 0.4.36 (Dec 5, 2024)</h2>
<ul>
<li>Breaking Changes
<ul>
<li>
<p>This release lands &quot;stackless&quot;, an internal change to JAX's
tracing
machinery. We made trace dispatch purely a function of context rather
than a
function of both context and data. This let us delete a lot of machinery
for
managing data-dependent tracing: levels, sublevels,
<code>post_process_call</code>,
<code>new_base_main</code>, <code>custom_bind</code>, and so on. The
change should only affect
users that use JAX internals.</p>
<p>If you do use JAX internals then you may need to
update your code (see
<a
href="https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f">https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f</a>
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks
your
non-JAX-internals-using code then try the
<code>config.jax_data_dependent_tracing_fallback</code> flag as a
workaround, and if
you need help updating your code then please file a bug.</p>
</li>
<li>
<p><code>jax.experimental.jax2tf.convert</code> with
<code>native_serialization=False</code> or with <code>enable_xla=False</code>
has been deprecated since July 2024 (JAX v0.4.31), and support for these
use cases has now been removed. <code>jax2tf</code> with native
serialization will still be supported.</p>
</li>
<li>
<p>In <code>jax.interpreters.xla</code>, the <code>xb</code>,
<code>xc</code>, and <code>xe</code> symbols have been removed
after being deprecated in JAX v0.4.31. Instead use <code>xb =
jax.lib.xla_bridge</code>,
<code>xc = jax.lib.xla_client</code>, and <code>xe =
jax.lib.xla_extension</code>.</p>
</li>
<li>
<p>The deprecated module <code>jax.experimental.export</code> has been
removed. It was replaced
by <code>jax.export</code> in JAX v0.4.30. See the <a
href="https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export">migration
guide</a>
for information on migrating to the new API.</p>
</li>
<li>
<p>The <code>initial</code> argument to
<code>jax.nn.softmax</code> and
<code>jax.nn.log_softmax</code>
has been removed, after being deprecated in v0.4.27.</p>
</li>
<li>
<p>Calling <code>np.asarray</code> on typed PRNG keys (i.e. keys
produced by <code>jax.random.key</code>)
now raises an error. Previously, this returned a scalar object
array.</p>
</li>
<li>
<p>The following deprecated methods and functions in
<code>jax.export</code> have
been removed:</p>
<ul>
<li><code>jax.export.DisabledSafetyCheck.shape_assertions</code>: it had
no effect
already.</li>
<li><code>jax.export.Exported.lowering_platforms</code>: use
<code>platforms</code>.</li>
<li><code>jax.export.Exported.mlir_module_serialization_version</code>:
use <code>calling_convention_version</code>.</li>
</ul>
</li>
</ul>
</li>
</ul>
</blockquote>
<p>... (truncated)</p>
</details>
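Two of the 0.4.36 removals listed above have direct replacements. A sketch, assuming JAX >= 0.4.30 (where `jax.export` exists) and the default PRNG implementation; `poly` is a hypothetical function used only for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# Typed PRNG keys: np.asarray(key) now raises, so extract the raw bits explicitly.
key = jax.random.key(0)
raw_bits = np.asarray(jax.random.key_data(key))  # uint32; shape depends on the PRNG impl


# jax.export replaces the removed jax.experimental.export module.
def poly(x):  # hypothetical function to export
    return 3.0 * x**2 + 1.0


exported = export.export(jax.jit(poly))(jax.ShapeDtypeStruct((3,), jnp.float32))
platforms = exported.platforms  # replaces the removed lowering_platforms
restored = export.deserialize(exported.serialize())  # round-trip through bytes
y = restored.call(jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32))
```

Here `y` holds `[1.0, 4.0, 13.0]`; the serialized form could equally be written to disk and deserialized in another process.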
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/jax-ml/jax/commit/ffb07cdadb5dc3bc43485cf041dbc2b43136109e"><code>ffb07cd</code></a>
Update versions for v0.4.37 release.</li>
<li><a
href="https://github.com/jax-ml/jax/commit/95892fdac86524151b6dadd7d8bedbf915f1500f"><code>95892fd</code></a>
Use private names for args in api_util to avoid shadowing kwargs
keys.</li>
<li><a
href="https://github.com/jax-ml/jax/commit/65b60884114261549ffc2eb937162bdeaa493928"><code>65b6088</code></a>
Avoid index out of range error in carry structure check</li>
<li><a
href="https://github.com/jax-ml/jax/commit/259194a69f52a06847a9ff11eb268072e91fd65f"><code>259194a</code></a>
[Pallas] Fix shard_axis in dma_start interpret mode rule.</li>
<li><a
href="https://github.com/jax-ml/jax/commit/7e6620a57775084dfa8d438ae4fd27f3ef365018"><code>7e6620a</code></a>
JAX release 0.4.36.</li>
<li><a
href="https://github.com/jax-ml/jax/commit/23d5c10ff0704f66ad7ec65a8cdcd09bd2420591"><code>23d5c10</code></a>
[Mosaic:TPU] Fix fully replicated relayout</li>
<li><a
href="https://github.com/jax-ml/jax/commit/2a4a0e8d6fb36b59f9c6f24e0018d42c8c8d8ee9"><code>2a4a0e8</code></a>
[jax:custom_partitioning] Implement SdyShardingRule to support</li>
<li><a
href="https://github.com/jax-ml/jax/commit/f73fa7a7ad64b2f15e8669beed14600704287b93"><code>f73fa7a</code></a>
Merge pull request <a
href="https://redirect.github.com/jax-ml/jax/issues/25290">#25290</a>
from jakevdp:reduction-where</li>
<li><a
href="https://github.com/jax-ml/jax/commit/a71f9a62e6f67640a4b0578d042b07792fcf407a"><code>a71f9a6</code></a>
Merge pull request <a
href="https://redirect.github.com/jax-ml/jax/issues/25271">#25271</a>
from jakevdp:fix-vector-norm</li>
<li><a
href="https://github.com/jax-ml/jax/commit/e20a483befbb80bbf782b931ec57a44c78c313b8"><code>e20a483</code></a>
[JAX] Add end-to-end execution support in colocated Python API</li>
<li>Additional commits viewable in <a
href="https://github.com/jax-ml/jax/compare/jax-v0.4.24...jax-v0.4.37">compare
view</a></li>
</ul>
</details>
<br />

Updates `diffrax` to 0.6.1
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/patrick-kidger/diffrax/releases">diffrax's
releases</a>.</em></p>
<blockquote>
<h2>Diffrax v0.6.1</h2>
<h3>Features</h3>
<ul>
<li>
<p>Compatibility with JAX 0.4.36.</p>
</li>
<li>
<p>New solvers! Added stochastic Runge--Kutta methods for solving the
underdamped Langevin equation. We now have:</p>
<ul>
<li><code>diffrax.AbstractFosterLangevinSRK</code></li>
<li><code>diffrax.ALIGN</code></li>
<li><code>diffrax.QUICSORT</code></li>
<li><code>diffrax.ShOULD</code></li>
</ul>
<p>and these are used with the corresponding</p>
<ul>
<li><code>diffrax.UnderdampedLangevinDriftTerm</code></li>
<li><code>diffrax.UnderdampedLangevinDiffusionTerm</code></li>
</ul>
<p>huge thanks to <a
href="https://github.com/andyElking"><code>@​andyElking</code></a> for
carefully implementing all of these, which was a huge technical task.
(<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/453">#453</a>
and 2000 new lines of code!) See <a
href="https://docs.kidger.site/diffrax/examples/underdamped_langevin_example/">the
Underdamped Langevin Diffusion example</a> for more on how to use
these.</p>
</li>
</ul>
<h3>Bugfixes</h3>
<ul>
<li>If <code>t0 == t1</code> and we have <code>SaveAt(ts=...)</code>
then we now correctly output <code>len(ts)</code> copies of
<code>y0</code>. (Thanks <a
href="https://github.com/dkweiss31"><code>@​dkweiss31</code></a>! <a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/488">#488</a>,
<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/494">#494</a>)</li>
<li>When using <code>diffrax.VirtualBrownianTree</code> on the GPU then
floating point fluctuations would sometimes produce evaluations outside
of the valid <code>[t0, t1]</code> region, which would raise a spurious
runtime error. This is now fixed. (Thanks <a
href="https://github.com/mattlevine22"><code>@​mattlevine22</code></a>!
<a
href="https://redirect.github.com/jax-ml/jax/issues/24807">jax-ml/jax#24807</a>,
<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/524">#524</a>,
<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/526">#526</a>)</li>
<li>Complex fixes in SDEs (Thanks <a
href="https://github.com/Randl"><code>@​Randl</code></a>! <a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/454">#454</a>)</li>
<li>Improvements to errors, warnings, and some typo fixes (Thanks <a
href="https://github.com/lockwo"><code>@​lockwo</code></a> <a
href="https://github.com/ddrous"><code>@​ddrous</code></a>! <a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/468">#468</a>, <a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/478">#478</a>,
<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/495">#495</a>,
<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/530">#530</a>)</li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a href="https://github.com/ddrous"><code>@​ddrous</code></a> made
their first contribution in <a
href="https://redirect.github.com/patrick-kidger/diffrax/pull/530">patrick-kidger/diffrax#530</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/patrick-kidger/diffrax/compare/v0.6.0...v0.6.1">https://github.com/patrick-kidger/diffrax/compare/v0.6.0...v0.6.1</a></p>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/78531fa2ae15f0e8ce7356148642122a16d3531a"><code>78531fa</code></a>
Bump minimum Equinox version to one that is compatible with latest
JAX</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/825e4e06a84d15abc523c3d61c9af94fdc48de67"><code>825e4e0</code></a>
Fixed where a nonbatchable check was being called.</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/1ae1d58cbbf2535f4e6bd27b5b52e7189c96b86f"><code>1ae1d58</code></a>
version bump</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/72deb78fbb30af068c46a7ad3be4c22ed05f2fd9"><code>72deb78</code></a>
Updated pre-commit to handle jaxtyping update</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/d3490e6c83402da8c58c70d4158d7f61e9f3e4a6"><code>d3490e6</code></a>
Fixes for JAX 0.4.36 which changes the name of an error.</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/3c21d15405b20c07da407cfe34a688a3b4d5ff24"><code>3c21d15</code></a>
Updates to the t0==t1 case to handle <code>SubSaveAt(fn=...)</code> and
nonstandard dtyp...</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/ebd798064248b4e7102047a032a3e0d89acf5c99"><code>ebd7980</code></a>
Save fix for <code>t0==t1</code> (<a
href="https://redirect.github.com/patrick-kidger/diffrax/issues/494">#494</a>)</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/965f6b403b7c799067be4f68fbf972e23a817aed"><code>965f6b4</code></a>
Compatibility with JAX 0.4.36, which removes ConcreteArray</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/beadc78dc7a58b58aceec18fa3bd86defb2f8c1d"><code>beadc78</code></a>
bump doc building pipeline</li>
<li><a
href="https://github.com/patrick-kidger/diffrax/commit/0cf67d1268283354b3522deede7e7730b0b250e1"><code>0cf67d1</code></a>
small fix of docs in all three and a return type in quicsort</li>
<li>Additional commits viewable in <a
href="https://github.com/patrick-kidger/diffrax/compare/v0.4.1...v0.6.1">compare
view</a></li>
</ul>
</details>
<br />

