diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 814c90d9..48b407a1 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -10,7 +10,7 @@ import jax.tree_util as jtu from jax.typing import ArrayLike -from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint +from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint from .custom_types import Array, Bool, Int, PyTree, Scalar from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation @@ -633,22 +633,6 @@ def diffeqsolve( "An SDE should not be solved with adaptive step sizes with Euler's " "method, as it may not converge to the correct solution." ) - # TODO: remove these lines. - # - # These are to work around an edge case: on the backward pass, - # RecursiveCheckpointAdjoint currently tries to differentiate the overall - # per-step function wrt all floating-point arrays. In particular this includes - # `state.tprev`, which feeds into the control, which feeds into - # VirtualBrownianTree, which can't be differentiated. - # We're waiting on JAX to offer a way of specifying which arguments to a - # custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more - # precisely determine what to differentiate wrt. - # - # We don't replace this in the case of an unsafe SDE because - # RecursiveCheckpointAdjoint will raise an error in that case anyway, so we - # should let the normal error be raised. - if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms): - adjoint = DirectAdjoint() if is_unsafe_sde(terms): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): raise ValueError( diff --git a/setup.py b/setup.py index 908c1d58..618c8a79 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ python_requires = "~=3.9" -install_requires = ["jax>=0.4.13", "equinox>=0.10.11"] +install_requires = ["jax>=0.4.13", "equinox>=0.11.1"] setuptools.setup( name=name,