sparse jacobian solve #545
Yup, both JAX and Diffrax are optimized around dense matrices. We do have some iterative linear solvers you can use (e.g. […]). Taking a quick look at the benchmark you have: […]
That aside, I'm curious to see a little more about the library you're writing here. We learnt a lot of lessons in writing Diffrax, many of them new to Diffrax, e.g.: […]
So if you're tackling the problem of writing an equivalent library in a new ecosystem, then I'd be happy to offer some thoughts on how to ensure it builds on prior art, both in the JAX ecosystem and (which I am also reasonably familiar with) the Julia ecosystem. I'm also curious whether your ambitions include GPU support and reverse-mode (or higher-order) autodiff? (So far as I can see, these aren't there atm.) I think these are becoming table stakes for new numerical software, and history has shown that they are hard to retrofit onto older code that was not originally written with them in mind. (Supporting these use cases is a big part of why Diffrax exists!)
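Iterative (Krylov) linear solvers, such as the `lineax.GMRES` solver in lineax, are attractive here because they only need matrix-vector products and never a dense matrix. As a rough illustration, and not lineax's implementation, here is a minimal unrestarted GMRES in NumPy driven by a hypothetical `matvec` callable. In a JAX setting, one might build that callable for a Newton matrix as `lambda v: v - gamma * jax.jvp(f, (y,), (v,))[1]`; that wiring is an assumption for illustration, not a description of what Diffrax does internally.

```python
import numpy as np


def gmres(matvec, b, max_iter=50, tol=1e-10):
    """Minimal unrestarted GMRES: solves A x = b using only v -> A @ v.

    `matvec` is a hypothetical callable (e.g. a Jacobian-vector product);
    the matrix A itself is never formed.
    """
    n = b.shape[0]
    beta = np.linalg.norm(b)
    if beta == 0.0:
        return np.zeros_like(b)
    m = min(max_iter, n)
    Q = np.zeros((n, m + 1))  # orthonormal basis of the Krylov subspace
    H = np.zeros((m + 1, m))  # upper Hessenberg matrix built by Arnoldi
    Q[:, 0] = b / beta
    for k in range(m):
        # Arnoldi step with modified Gram-Schmidt orthogonalisation.
        w = matvec(Q[:, k])
        for j in range(k + 1):
            H[j, k] = Q[:, j] @ w
            w = w - H[j, k] * Q[:, j]
        H[k + 1, k] = np.linalg.norm(w)
        # Minimise ||beta * e1 - H y|| over the current subspace.
        # (Production codes update a QR factorisation with Givens rotations.)
        e1 = np.zeros(k + 2)
        e1[0] = beta
        y, *_ = np.linalg.lstsq(H[: k + 2, : k + 1], e1, rcond=None)
        residual = np.linalg.norm(H[: k + 2, : k + 1] @ y - e1)
        if residual <= tol * beta or H[k + 1, k] == 0.0:
            return Q[:, : k + 1] @ y
        Q[:, k + 1] = w / H[k + 1, k]
    return Q[:, :m] @ y


# Usage: a block-diagonal operator applied block by block, never stored densely.
rng = np.random.default_rng(0)
blocks = np.eye(3) + 0.1 * rng.normal(size=(4, 3, 3))


def matvec(v):
    return np.einsum("bij,bj->bi", blocks, v.reshape(4, 3)).reshape(-1)


b = rng.normal(size=12)
x = gmres(matvec, b)
```

The cost per iteration is one `matvec` plus orthogonalisation, so for a sparse or block-structured operator the work scales with the number of nonzeros rather than with the square of the system size.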
Thanks @patrick-kidger. Good call on the […]. For the iterative solvers in […]. Generally, it would be very useful to get your thoughts on the current state of the art for ODE solvers and where you see the interesting areas for growth. It would be great to organise a call sometime in the new year if you are amenable.
GPU support is definitely on my roadmap. diffsol is designed to be generic over vector, matrix and linear solver types, so I can swap different linear algebra libraries in and out. At the moment I'm hampered by the lack of decent GPU-based linear algebra libraries in Rust, but I'm also considering wrapping a C library as a temporary solution until the Rust ecosystem matures.
I've implemented the solution of the continuous adjoint equations, with checkpointing. It's not really mentioned in the docs yet because I'm still playing around with the API. I don't do any backprop through the solver, however. There is currently a PR into nightly Rust that will add autodiff via Enzyme (rust-lang/rust#124509), and once this is done I'm going to try to implement backprop through the solver as well.
I had a read through your Term system, which was really interesting. I've done something similar in purpose (the structure is a bit different) using a system of operator traits (nonlinear, linear, etc.), and then defining the ODE equations as a set of operators. Different solvers can place trait bounds on these types that define which form of equations they are able to solve, so users get compile-time errors if their equations aren't suitable for the solver used. The way that you've incorporated SDEs into the Term system is great; I'd have to think about how this could fit into diffsol. Do you see much usage of SDEs with Diffrax (compared with deterministic equations)?
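The unifying idea behind a term-based design like the one discussed here is that every term is a vector field multiplied by the increment of a control: the control is time itself for the drift, and a Brownian motion for the diffusion. A single Euler step therefore covers both explicit Euler (ODE) and Euler-Maruyama (SDE). The sketch below is a simplified illustration using hypothetical names (`Term`, `time_increment`, `brownian_increment`, `euler_step`, `solve`); it is not Diffrax's actual `ODETerm`/`ControlTerm`/`MultiTerm` API.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable


@dataclass
class Term:
    """Hypothetical term: a vector field times an increment of a control."""
    vector_field: Callable[[float, float], float]  # (t, y) -> f(t, y)
    increment: Callable[[float, float], float]     # (t0, t1) -> X(t1) - X(t0)


def time_increment(t0, t1):
    # Control X(t) = t, so the increment is just the step size.
    return t1 - t0


def brownian_increment(rng):
    # Control X(t) = W(t): Gaussian increments with variance t1 - t0.
    # Simplification: every call draws a fresh sample, unlike a proper
    # Brownian path object that answers repeated queries consistently.
    def increment(t0, t1):
        return rng.gauss(0.0, math.sqrt(t1 - t0))
    return increment


def euler_step(terms, t0, t1, y):
    # One update rule for ODEs and SDEs alike: y1 = y0 + sum_i f_i(t0, y0) * dX_i.
    return y + sum(term.vector_field(t0, y) * term.increment(t0, t1) for term in terms)


def solve(terms, t0, t1, y0, num_steps):
    dt = (t1 - t0) / num_steps
    y = y0
    for i in range(num_steps):
        y = euler_step(terms, t0 + i * dt, t0 + (i + 1) * dt, y)
    return y


# Usage: the same stepping code solves an ODE and an SDE.
drift = Term(lambda t, y: -y, time_increment)
noise = Term(lambda t, y: 0.1 * y, brownian_increment(random.Random(0)))
y_ode = solve([drift], 0.0, 1.0, 1.0, 100)          # explicit Euler, about 0.99 ** 100
y_sde = solve([drift, noise], 0.0, 1.0, 1.0, 100)   # Euler-Maruyama
```

Because the solver only ever sees "vector field times increment", adding noise is a matter of adding a term rather than writing a new solver.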
I'd be very happy to take a PR on BDF! No plans to implement it myself right now.
On lineax and sparse solves: haha, I think you're right. This should 'just work' at the moment simply by passing the appropriate linear solver to the differential equation solver. (Since under the hood it will create a […].)
As for the current SOTA: very happy to have a call. I can see you're also in Oxford, and I'll be passing through London early in the new year, if by any chance you're in town then too.
When it comes to backpropagation, I'd strongly discourage using the continuous adjoint. This is an almost universally bad idea with essentially no redeeming qualities. 😄 The best option is almost always discretise-then-optimise with recursive checkpointing. (See Equinox's checkpointed while loop.) In the future this may be supplanted by algebraically reversible differential equation solvers, but they're still an open research topic.
I'll be really curious to see what autodiff looks like in Rust at the language level. I've mostly interacted with framework-level autodiff (PyTorch, JAX) and have had bad experiences with language-level autodiff before (Julia).
As for SDEs, yup, I get quite a lot of users of these. Diffrax was actually originally a research project to see if we could write out the numerics for ODEs and SDEs in a single unified way (and we can), so they've been in since the start!
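To make "discretise-then-optimise with recursive checkpointing" concrete: the gradient is taken through the solver's own discrete steps, but instead of storing every intermediate state, only a few checkpoints are kept and the rest are recomputed. The sketch below uses simple bisection on a hypothetical explicit Euler discretisation of y' = -y^3 (`step`, `step_vjp`, `checkpointed_grad` are illustrative names). It holds O(log N) checkpoints on the recursion stack at the cost of O(N log N) step evaluations; Equinox's checkpointed while loop uses a more sophisticated schedule, so treat this only as an illustration of the idea.

```python
H = 0.01  # step size (illustrative)


def step(y):
    # One explicit Euler step for y' = -y**3.
    return y + H * (-y ** 3)


def step_vjp(y, cotangent):
    # Pull a cotangent back through one step: multiply by d(step)/dy at y.
    return cotangent * (1.0 - 3.0 * H * y ** 2)


def advance(y, n):
    for _ in range(n):
        y = step(y)
    return y


def checkpointed_grad(y_start, n, cotangent):
    """Pull `cotangent` back through n >= 1 steps starting at y_start.

    Each recursion level stores only its start state; intermediate states
    are recomputed, so memory grows like log2(n) rather than n.
    """
    if n == 1:
        return step_vjp(y_start, cotangent)
    half = n // 2
    y_mid = advance(y_start, half)                           # recompute forward to the midpoint
    cot_mid = checkpointed_grad(y_mid, n - half, cotangent)  # backprop the later half first
    return checkpointed_grad(y_start, half, cot_mid)         # then the earlier half


def stored_grad(y0, n):
    # Reference: store the whole trajectory, then sweep backwards (O(n) memory).
    ys = [y0]
    for _ in range(n):
        ys.append(step(ys[-1]))
    cotangent = 1.0
    for y in reversed(ys[:-1]):
        cotangent = step_vjp(y, cotangent)
    return cotangent


# Usage: d y_1000 / d y_0 both ways.
g_ckpt = checkpointed_grad(0.8, 1000, 1.0)
g_ref = stored_grad(0.8, 1000)
```

Both functions differentiate the discrete solution exactly (up to rounding); the continuous adjoint instead solves a separate backward ODE whose discretisation generally does not match the forward one.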
While I won't be able to do a BDF PR to Diffrax myself, I could get someone else started, as I've implemented BDF in JAX before here. It's not the greatest code for understanding the algorithm, due to the hoops that JAX makes you jump through; the SciPy implementation is easier to follow. @BradyPlanden has been making some recent improvements to the JAX BDF code in PyBaMM, so I'll tag him on this in case he is interested in contributing to Diffrax.
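For anyone picking up a BDF implementation, the core update is small even though a production solver (variable order and step size, error control, Jacobian reuse, as in SciPy's `BDF`) is not. The sketch below is a simplification for illustration, not the SciPy, PyBaMM or Diffrax code: a fixed-step, fixed-order BDF2 with a dense Newton iteration on the implicit equation, started with one backward Euler step. The function names `newton_solve` and `bdf2` are hypothetical.

```python
import numpy as np


def newton_solve(residual, jacobian, y_guess, tol=1e-12, max_iter=20):
    # Solve residual(y) = 0 with a dense Newton iteration.
    y = y_guess.copy()
    for _ in range(max_iter):
        delta = np.linalg.solve(jacobian(y), -residual(y))
        y = y + delta
        if np.linalg.norm(delta) <= tol * (1.0 + np.linalg.norm(y)):
            return y
    raise RuntimeError("Newton iteration did not converge")


def bdf2(f, jac, t0, y0, h, num_steps):
    """Fixed-step BDF2: y_{n+1} - 4/3 y_n + 1/3 y_{n-1} = 2/3 h f(t_{n+1}, y_{n+1})."""
    y0 = np.asarray(y0, dtype=float)
    eye = np.eye(y0.size)
    # Start-up: one backward Euler step, y_1 - y_0 = h f(t_1, y_1).
    t1 = t0 + h
    y1 = newton_solve(
        lambda y: y - y0 - h * f(t1, y),
        lambda y: eye - h * jac(t1, y),
        y0,
    )
    y_prev, y_curr = y0, y1
    for n in range(1, num_steps):
        t_next = t0 + (n + 1) * h
        history = (4.0 / 3.0) * y_curr - (1.0 / 3.0) * y_prev
        # Newton matrix is I - (2/3) h J, the same shape of system a stiff
        # solver has to factorise (or solve iteratively) at every step.
        y_next = newton_solve(
            lambda y: y - history - (2.0 / 3.0) * h * f(t_next, y),
            lambda y: eye - (2.0 / 3.0) * h * jac(t_next, y),
            y_curr,
        )
        y_prev, y_curr = y_curr, y_next
    return y_curr


# Usage: y' = -1000 y is very stiff, yet BDF2 remains stable with h = 0.1.
y_end = bdf2(lambda t, y: -1000.0 * y, lambda t, y: -1000.0 * np.eye(1), 0.0, [1.0], 0.1, 10)
```

The explicit Newton matrix here is dense; for the block-diagonal problems in this thread, that solve is exactly where a structured or iterative linear solver would be swapped in.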
Can the implicit solvers in Diffrax use a sparse matrix solve for the Jacobian? I'm putting together a benchmark with a few different ODE solvers, including Diffrax, and the problem in question is stiff (it's based on the Robertson ODE problem from your examples) and has a block-diagonal Jacobian structure with small blocks (each block is 3x3 and the total matrix size goes up to 10,000, so very sparse). Diffrax slows down significantly at larger matrix sizes, and my assumption is that you are using a dense linear solver. Is that right, and is there a way of swapping this out for a sparse solver? My understanding is that sparse matrix support is still rather experimental in JAX.
You can see the benchmark and results here: https://martinjrobins.github.io/diffsol/benchmarks/python.html. Please feel free to let me know if I'm not using Diffrax correctly. I've not used it in anger before, so it's entirely possible I'm doing something stupid.
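For the benchmark problem described above, N uncoupled copies of the Robertson system have a block-diagonal Jacobian with 3x3 blocks, so a Newton system of the form (I - gamma J) x = r splits into N independent 3x3 solves. The NumPy sketch below computes the blocks analytically and solves them in one batched call, which costs O(N) work and memory instead of the O(N^3) work and O(N^2) memory of a dense factorisation. The names `robertson_rhs`, `robertson_jac_blocks`, `solve_newton_system` and the coefficient `gamma` (standing in for the step-size-dependent factor) are hypothetical; this illustrates exploiting the structure and is not how Diffrax or diffsol handle it.

```python
import numpy as np

K1, K2, K3 = 0.04, 3e7, 1e4  # Robertson rate constants


def robertson_rhs(y):
    # y has shape (n, 3): n independent copies of the Robertson system.
    y1, y2, y3 = y[:, 0], y[:, 1], y[:, 2]
    return np.stack(
        [
            -K1 * y1 + K3 * y2 * y3,
            K1 * y1 - K3 * y2 * y3 - K2 * y2 ** 2,
            K2 * y2 ** 2,
        ],
        axis=1,
    )


def robertson_jac_blocks(y):
    # Returns the (n, 3, 3) diagonal blocks of the full (3n, 3n) Jacobian.
    y2, y3 = y[:, 1], y[:, 2]
    J = np.zeros((y.shape[0], 3, 3))
    J[:, 0, 0] = -K1
    J[:, 0, 1] = K3 * y3
    J[:, 0, 2] = K3 * y2
    J[:, 1, 0] = K1
    J[:, 1, 1] = -K3 * y3 - 2.0 * K2 * y2
    J[:, 1, 2] = -K3 * y2
    J[:, 2, 1] = 2.0 * K2 * y2
    return J


def solve_newton_system(y, gamma, rhs):
    # Solve (I - gamma * J) x = rhs block by block; rhs has shape (n, 3).
    blocks = np.eye(3) - gamma * robertson_jac_blocks(y)
    # Each right-hand side is passed as an explicit (3, 1) column so the
    # batched solve broadcasts the same way under NumPy 1.x and 2.x.
    return np.linalg.solve(blocks, rhs[..., None])[..., 0]


# Usage: five copies, states on a typical Robertson scale.
rng = np.random.default_rng(0)
n = 5
y = np.abs(rng.normal(size=(n, 3))) * np.array([1.0, 1e-5, 1.0])
rhs = rng.normal(size=(n, 3))
gamma = 1e-3
x = solve_newton_system(y, gamma, rhs)
```

The same structure could instead be handed to a sparse or iterative solver; the point is that nothing in the problem requires a dense 10,000 x 10,000 factorisation.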