[Question] Torchsde to diffrax conversion #317
Comments
Pyplot code for reproduction:
One thing that does jump out is that you're using Other than that, do also note that you're using an Ito solver in torchsde, but a Stratonovich solver in Diffrax. In this particular example, your noise is such that Ito and Stratonovich should actually converge to the same result; just be aware that this is the case. |
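To make the Ito/Stratonovich remark concrete, here is a minimal sketch in plain Python (no torchsde or Diffrax). It assumes the reason the two interpretations agree is that the diffusion does not depend on the state (additive noise); the thread does not show the SDE, so that is an assumption. The drift, the diffusion values and the helper names `euler_maruyama` and `euler_heun` are all hypothetical and stand in for an Ito solver and a Stratonovich solver respectively.

```python
import math
import random


def euler_maruyama(f, g, y0, dt, dws):
    """Ito scheme: the diffusion is evaluated at the start of each step."""
    y = y0
    for dw in dws:
        y = y + f(y) * dt + g(y) * dw
    return y


def euler_heun(f, g, y0, dt, dws):
    """Stratonovich scheme: the diffusion is averaged over a predictor step."""
    y = y0
    for dw in dws:
        y_pred = y + g(y) * dw
        y = y + f(y) * dt + 0.5 * (g(y) + g(y_pred)) * dw
    return y


def drift(y):
    return -y  # hypothetical mean-reverting drift, not the SDE from this issue


# One fixed Brownian path, shared by both schemes.
rng = random.Random(0)
n_steps, dt = 1000, 1e-3
dws = [rng.gauss(0.0, math.sqrt(dt)) for _ in range(n_steps)]


def additive(y):
    return 0.5  # state-independent diffusion


def multiplicative(y):
    return 0.5 * y  # state-dependent diffusion


additive_gap = abs(
    euler_maruyama(drift, additive, 1.0, dt, dws)
    - euler_heun(drift, additive, 1.0, dt, dws)
)
multiplicative_gap = abs(
    euler_maruyama(drift, multiplicative, 1.0, dt, dws)
    - euler_heun(drift, multiplicative, 1.0, dt, dws)
)
print("additive noise gap:", additive_gap)
print("multiplicative noise gap:", multiplicative_gap)
```

With additive noise the two schemes take identical steps, so the solver mismatch alone would not explain a large discrepancy. With state-dependent noise the two endpoints differ on the same Brownian path, and the Ito and Stratonovich results would no longer be comparable without a drift correction.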
Great! Glad that fixed things. In terms of speed -- that could be an Euler/SRK difference, but for something as big as 10x it's probably more something to do with the step size controller. In Diffrax right now you're just using a simple I-controller, but for an SDE solve you usually want at least a PI-controller. Take a look at the docs for Likewise you might be using different tolerances etc. I'd encourage checking exactly where Diffrax is making steps by using |
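As a rough illustration of the I-controller versus PI-controller distinction, the sketch below uses one common textbook formulation of the step size update. It is not taken from Diffrax, whose exact formula may differ, and the function name `step_factor`, the parameter `error_order` and the coefficient values are hypothetical.

```python
def step_factor(err, prev_err, error_order, icoeff, pcoeff, safety=0.9):
    """Multiplier applied to the current step size to get the next one.

    err and prev_err are error estimates already divided by the tolerance,
    so 1.0 means "exactly on target" and values above 1.0 mean "too large".
    With pcoeff=0 this is a pure integral (I) controller, which only looks
    at the current error; a nonzero pcoeff also reacts to how the error
    changed since the previous step.
    """
    integral_and_proportional = err ** (-(icoeff + pcoeff) / error_order)
    proportional_memory = prev_err ** (pcoeff / error_order)
    return safety * integral_and_proportional * proportional_memory


# I-controller: only the current error matters.
i_factor = step_factor(4.0, 0.25, error_order=2, icoeff=1.0, pcoeff=0.0)

# PI-controller: the same current error gives a smaller step when the error
# has just risen than when it has been steady at that level.
pi_rising = step_factor(2.0, 1.0, error_order=2, icoeff=0.3, pcoeff=0.1)
pi_steady = step_factor(2.0, 2.0, error_order=2, icoeff=0.3, pcoeff=0.1)

print(i_factor, pi_rising, pi_steady)
```

In Diffrax the analogous knobs are the `pcoeff` and `icoeff` arguments of `diffrax.PIDController`; the values used here are placeholders, not recommendations.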
Thank you for the suggestion. I tried all of the following PID Controllers from the docs:
each for various When I performed the same analysis with torchsde SRK, I found that the step size stabilized at Do you think we can chalk this up to the solver or is there something else in the adaptive step size controller? |
Hmm. I would suggest trying a different solver, e.g. just |
Thanks! I tried The PID Controller I used for reproduction purposes: |
BTW, do you think there's a faster way to do
FWIW Diffrax is almost always a fair bit faster than torchsde, so the results you're seeing can probably be improved! (Making it an apples-to-apples comparison of step size and solver would help.) For batching, use |
Resolved! I got about the same performance as torchsde with just the |
Oh, that is curious! Do you have a MWE you can share? And on what hardware? |
On an A100, this is Trevor McCourt's benchmark with the
that runs @2.7s ish on my A100 instance with |
The main changes needed to make this happen are in patrick-kidger/equinox#548 and as such this commit is fairly small -- it declares a dependency on a new (as yet unreleased) version of Equinox, and removes the compatibility shim that was there before.
**How things used to be.** To explain what's going on here a little more carefully: JAX only recently added support for communicating to a `custom_vjp` which input arguments were being perturbed, and which output cotangents were symbolic zeros. Prior to that, a `custom_vjp` basically just had to differentiate all inexact arrays. However, there are some nondifferentiable inexact arrays, in the sense that attempting to differentiate them will raise an error. Solving SDEs has one such array: the nondifferentiable input to a VirtualBrownianTree, which is guarded by an `eqxi.nondifferentiable` to reflect the fact that Brownian motion is nondifferentiable. So it used to be the case that the `custom_vjp` underlying `RecursiveCheckpointAdjoint` would differentiate the overall make-a-step function with respect to all inexact arrays, including the time variable, hit the `nondifferentiable` guard, and crash. One (unsafe) fix would have just been to remove the `nondifferentiable` guard. In practice I previously took the slower, safer option: silently switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The latter is much less efficient, but uses no `custom_vjp`, and thus used the perturbation and symbolic-zero propagation rules already present in JAX's AD machinery, and was thus safe to use here.
**How things are now.** And so, what has now changed: JAX has now added support for tracking which inputs are perturbed, and which cotangents are symbolic zeros. The Equinox PR above uses this functionality to determine which parts of the carry need to be differentiated; no longer is it just "all inexact arrays". And thus this Diffrax PR removes the compatibility shim that is no longer needed.
**Implications.** SDE solving should now be much faster. In particular this fixes the speed issue reported in #317.
Alright, thank you for the MWE! This ended up being an interesting one to track down. You can find the details in #320 and patrick-kidger/equinox#548. Basically, solving an SDE with Using the above PRs and on a V100, I find that the speeds for your benchmark are identical regardless of the value of |
Thanks for addressing this! Will switch over to the new eqx version. |
I'm trying to convert this simple SDE that converges in torchsde to run on diffrax and produce corroborating results but I'm seeing very different results on diffrax with all of the solvers available.
The torchsde code:
And averaging over the batch size dimension produces the following correct result:
However, converting the same SDE to diffrax appears to produce very different results, and takes much more time.
This was the diffrax code I ran to implement the same SDE:
Am I using diffrax wrong?