[Question] Torchsde to diffrax conversion #317

Closed · Sinestro38 opened this issue Oct 2, 2023 · 14 comments

@Sinestro38

I'm trying to convert this simple SDE, which converges under torchsde, to run on diffrax and produce corroborating results, but I'm seeing very different results in diffrax with all of the available solvers.

The torchsde code:

import numpy as np
import torch
import torchsde

def arbitrary_force_fn(grid_x, l_bar = 1, rho = 1.5, lambda_th = 1, phi_ext = np.pi/16 ):
    return 2 * l_bar/(lambda_th**2) * (grid_x - lambda_th * rho * torch.sin(grid_x/lambda_th + phi_ext))

batch_size = 100000
init_states = torch.broadcast_to(torch.Tensor([0.0, -1]), (batch_size, 2))
times = torch.linspace(start=0, end=20, steps=1000)

class MyParticle(torch.nn.Module):
    def __init__(self, force_fn):
        super().__init__()
        self.force_fn = force_fn
        self.noise_type = "diagonal"
        self.sde_type = "ito"
        self.noise_vec = torch.nn.Parameter(torch.tensor([0, np.sqrt(2)]))

    def f(self, _, y):
        x = y[..., 0]
        p = y[..., 1]
        return torch.stack([p, -self.force_fn(x) - p], dim=-1)

    def g(self, _, y):
        return torch.broadcast_to(self.noise_vec, y.shape)

sde = MyParticle(arbitrary_force_fn)
device_cuda = torch.device("cuda")
with torch.no_grad():
    solns = torchsde.sdeint(sde.to(device_cuda), init_states.to(device_cuda), times.to(device_cuda), adaptive=True)
solns = torch.swapaxes(solns, 0, 1) # of shape (batch_size, times, 2) after swapping

And averaging over the batch size dimension produces the following correct result:
[Plot: batch-averaged ⟨y⟩ and ⟨v⟩ against time, showing the expected torchsde solution]

However, converting the same SDE to diffrax appears to produce very different results, and takes much more time.
[Plot: batch-averaged ⟨y⟩ and ⟨v⟩ from diffrax, visibly diverging from the torchsde result]

This was the diffrax code I ran to implement the same SDE:

import jax.random as jr
import jax.numpy as jnp
from diffrax import diffeqsolve, ReversibleHeun, ControlTerm, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree, PIDController

def arbitrary_force_fn(grid_x, l_bar = 1, rho = 1.5, lambda_th = 1, phi_ext = jnp.pi/16 ):
    return 2 * l_bar/(lambda_th**2) * (grid_x - lambda_th * rho * jnp.sin(grid_x/lambda_th + phi_ext))

batch_size = 100000
init_states = jnp.broadcast_to(jnp.array([0, -1]), (batch_size, 2))
t0, t1 = 0, 20
times = jnp.linspace(t0, t1, 1000)

noise_vec = jnp.asarray([0, jnp.sqrt(2)])

def drift(t, y, args):
    x = y[..., 0]
    p = y[..., 1]
    return jnp.stack([p, -arbitrary_force_fn(x) - p], axis=-1)

def diffusion(t, y, args):
    return jnp.broadcast_to(noise_vec, y.shape)

brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(batch_size, 2,), key=jr.PRNGKey(0))
sde = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver, saveat = ReversibleHeun(), SaveAt(ts=times)
stepsize_controller = PIDController(rtol=1e-3, atol=1e-3)
sol = diffeqsolve(sde, solver, t0, t1, dt0=0.001, y0=init_states, saveat=saveat, stepsize_controller=stepsize_controller)
sol_y = sol.ys.swapaxes(0, 1) # of shape (batch_size, times, 2) after swapping

Am I using diffrax wrong?

@Sinestro38 (Author)

Pyplot code for reproduction:

import matplotlib.pyplot as plt

avg_soln = jnp.mean(sol_y, axis=0)
fig, axs = plt.subplots(nrows=2, sharex=True)
axs[0].plot(times, avg_soln[:, 0])
axs[1].plot(times, avg_soln[:, 1])
axs[1].set_xlabel("Time")
axs[0].set_ylabel(r"$\langle y \rangle$")
axs[1].set_ylabel(r"$\langle v \rangle$")

@patrick-kidger (Owner)

One thing that does jump out is that you're using noise_type = "diagonal" in torchsde (which would correspond to WeaklyDiagonalControlTerm in Diffrax), but ControlTerm in Diffrax (which would correspond to noise_type = "general" in torchsde).

Other than that, do also note that you're using an Ito solver in torchsde, but a Stratonovich solver in Diffrax. In this particular example, your noise is such that Ito and Stratonovich should actually converge to the same result; just be aware that this is the case.
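
For concreteness, a minimal sketch of the diagonal-noise fix (reusing drift, diffusion and brownian_motion from your snippet; everything else stays the same):

from diffrax import MultiTerm, ODETerm, WeaklyDiagonalControlTerm

# With diagonal noise, `diffusion` returns the diagonal of the noise matrix
# (same shape as y), matching torchsde's noise_type = "diagonal".
sde = MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))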

@Sinestro38 (Author)

Thanks! I got the correct solution using WeaklyDiagonalControlTerm. However, the Euler solver I used with a constant step size was 10x slower than torchsde; I assume that's because the SRK method that torchsde uses just isn't implemented in diffrax yet.

[Plot: corrected diffrax solution, now matching the torchsde result]

@patrick-kidger (Owner)

Great! Glad that fixed things.

In terms of speed -- that could be an Euler/SRK difference, but for something as big as 10x it's probably more something to do with the step size controller. In Diffrax right now you're just using a simple I-controller, but for an SDE solve you usually want at least a PI-controller. Take a look at the docs for diffrax.PIDController for more info on how to set that up.

Likewise, you might be using different tolerances etc. I'd encourage checking exactly where Diffrax is making steps by using saveat=diffrax.SaveAt(steps=True) and then inspecting sol.ts. (I don't think torchsde has an analogous way of getting this output -- one of the many features Diffrax adds relative to torchsde -- but you could easily insert some print statements into its core integration loop.)
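
For example, a sketch of both suggestions together (reusing sde, solver, t0, t1 and init_states from your snippet; the PI coefficients here are just a starting point, not a recommendation):

import jax.numpy as jnp
from diffrax import PIDController, SaveAt, diffeqsolve

stepsize_controller = PIDController(rtol=1e-3, atol=1e-3, pcoeff=0.1, icoeff=0.3, dcoeff=0)
saveat = SaveAt(steps=True)  # record every accepted step instead of fixed output times
sol = diffeqsolve(sde, solver, t0, t1, dt0=0.001, y0=init_states,
                  saveat=saveat, stepsize_controller=stepsize_controller)

# With steps=True, sol.ts is padded out (with inf) to max_steps; keep the finite entries.
step_times = sol.ts[jnp.isfinite(sol.ts)]
print(jnp.diff(step_times))  # the step sizes the controller actually chose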

@Sinestro38 (Author)

Sinestro38 commented Oct 2, 2023

Thank you for the suggestion. I tried all of the following PIDController settings from the docs:

PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0)
PIDController(pcoeff=0.3, icoeff=0.3, dcoeff=0)
PIDController(pcoeff=0.2, icoeff=0.4, dcoeff=0)
PIDController(pcoeff=0.1, icoeff=0.3, dcoeff=0)

each for various rtol/atol configs, but found that each time the dt consistently decayed and invariably hit the solver's max_steps wall. Increasing max_steps didn't do much to help, since the dt slowly keeps decaying, which seems unusual (it starts at 0.01 but keeps decaying past 1e-3 and further). This was with ReversibleHeun(), since it supports adaptive step sizes. So the integration never finished without erroring out, but I obtained the dt by inserting a jax.debug.print() into adaptive.py.

When I performed the same analysis with torchsde's SRK, I found that the step size stabilized at 0.03744, which seemed much more reasonable.

Do you think we can chalk this up to the solver, or is there something else going on in the adaptive step size controller?

@patrick-kidger (Owner)

Hmm. I would suggest trying a different solver, e.g. just Heun. What does that produce?
ReversibleHeun is a fairly unstable solver, so it's very possible that it's doing something funny.
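
Something like (a one-line change to your setup):

from diffrax import Heun

solver = Heun()  # plain, non-reversible Heun; also a Stratonovich solver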

@Sinestro38 (Author)

Thanks! I tried Heun and the dt appeared to stabilize at around 0.00868. The solution is also correct. torchsde on GPU is still around 5x faster, but I think the ~4.3x larger step size has a lot to do with that.

The PID Controller I used for reproduction purposes: PIDController(rtol=1e-3, atol=1e-4, pcoeff=0.1, icoeff=0.3, dcoeff=0)

@Sinestro38 (Author)

BTW, do you think there's a faster way to run diffeqsolve on a batch of trajectories? I'm trying to average over a batch_size of 100k trajectories to get the average solution. Currently I accomplish this by setting the shape of the VirtualBrownianTree to (batch_size, ...) and passing in a y0 that is also of shape (batch_size, ...).

def arbitrary_force_fn(grid_x, l_bar = 1, rho = 1.5, lambda_th = 1, phi_ext = jnp.pi/16 ):
    return 2 * l_bar/(lambda_th**2) * (grid_x - lambda_th * rho * jnp.sin(grid_x/lambda_th + phi_ext))

batch_size = 100000
init_states = jnp.broadcast_to(jnp.array([0, -1]), (batch_size, 2))
t0, t1 = 0, 20
times = jnp.linspace(t0, t1, 1000)

noise_vec = jnp.asarray([0, jnp.sqrt(2)])

def drift(t, y, args):
    x = y[..., 0]
    p = y[..., 1]
    return jnp.stack([p, -arbitrary_force_fn(x) - p], axis=-1)

def diffusion(t, y, args):
    return jnp.broadcast_to(noise_vec, y.shape)

brownian_motion = VirtualBrownianTree(t0, t1, tol=(times[1]-times[0]), shape=(batch_size, 2,), key=jr.PRNGKey(0))
sde = MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))
solver, saveat = Heun(), SaveAt(ts=times)
stepsize_controller = PIDController(rtol=1e-3, atol=1e-4, pcoeff=0.1, icoeff=0.3, dcoeff=0)

sol = diffeqsolve(sde, solver, t0, t1, dt0=0.001, y0=init_states, saveat=saveat, stepsize_controller=stepsize_controller)
sol_y = sol.ys.swapaxes(0, 1) # of shape (batch_size, times, 2) after swapping

avg_soln = jnp.mean(sol_y, axis=0)

@patrick-kidger (Owner)

FWIW Diffrax is almost always a fair bit faster than torchsde, so the results you're seeing can probably be improved! (Making it an apples-to-apples comparison of step size and solver would help.)

For batching, use jax.vmap.
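
A minimal sketch of that pattern, reusing drift, diffusion, times, t0, t1, batch_size and init_states from your snippet. Each trajectory gets its own Brownian path via a distinct PRNG key, and all per-sample shapes become (2,) rather than (batch_size, 2):

import jax
import jax.numpy as jnp
import jax.random as jr
from diffrax import (Heun, MultiTerm, ODETerm, PIDController, SaveAt,
                     VirtualBrownianTree, WeaklyDiagonalControlTerm, diffeqsolve)

def solve_one(key, y0):
    # One Brownian path and one solve per trajectory.
    bm = VirtualBrownianTree(t0, t1, tol=times[1] - times[0], shape=(2,), key=key)
    terms = MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, bm))
    sol = diffeqsolve(terms, Heun(), t0, t1, dt0=0.001, y0=y0,
                      saveat=SaveAt(ts=times),
                      stepsize_controller=PIDController(rtol=1e-3, atol=1e-4,
                                                        pcoeff=0.1, icoeff=0.3, dcoeff=0))
    return sol.ys  # shape (len(times), 2)

keys = jr.split(jr.PRNGKey(0), batch_size)
sol_y = jax.vmap(solve_one)(keys, init_states)  # shape (batch_size, len(times), 2)
avg_soln = jnp.mean(sol_y, axis=0)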

@Sinestro38 (Author)

Resolved! I got about the same performance as torchsde with just the Heun() solver now; the issue was with max_steps. diffeqsolve ran around 6x faster with max_steps=None (it defaults to 4096). I think the step counter must be slowing things down.

@patrick-kidger (Owner)

Oh, that is curious! Do you have a MWE you can share? And on what hardware?

@Sinestro38 (Author)

Sinestro38 commented Oct 4, 2023

On an A100, this is Trevor McCourt's benchmark with max_steps=None:

import diffrax as dx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import time

nu = 1
init_conds = jr.uniform(jr.PRNGKey(0), minval=-1, maxval=1, shape=(2,))
batch_size = 10000
soln_times = np.linspace(start=0, stop=20, num=1000)
rtol = 0.005

device = jax.devices('gpu')[0]
solver = dx.Heun()
step_controller = dx.PIDController(rtol=rtol, atol=0.1 * rtol, pcoeff=0.1, icoeff=0.3, dcoeff=0)

def drift(t, y, args):
    x = y[..., 0]
    p = y[..., 1]
    return jnp.stack([p, -nu * x - p], axis=-1)

noise_vec = jnp.asarray([0, jnp.sqrt(2)])

def diffusion(t, y, args):
    return jnp.broadcast_to(noise_vec, y.shape)

with jax.default_device(device):
    init_state = jnp.broadcast_to(init_conds, (batch_size, 2))

    brownian_motion = dx.VirtualBrownianTree(soln_times[0], soln_times[-1], tol=(soln_times[1] - soln_times[0])/100, shape=(batch_size, 2,), key=jax.random.PRNGKey(0))
    sde = dx.MultiTerm(dx.ODETerm(drift), dx.WeaklyDiagonalControlTerm(diffusion, brownian_motion))

    saveat = dx.SaveAt(ts=soln_times)

    def precompile_solve():
        return dx.diffeqsolve(sde, solver, soln_times[0], soln_times[-1], dt0=0.001, y0=init_state,
                              saveat=saveat,
                              stepsize_controller=step_controller, max_steps=None)

    precompiled_runner = jax.jit(precompile_solve).lower().compile()

    start_time = time.time()
    sol = precompiled_runner().ys.block_until_ready()
    end_time = time.time()
sols, jax_time = np.array(sol.swapaxes(0, 1)), end_time - start_time

print("jax_time: ", jax_time)

That runs in ~2.7 s on my A100 instance with max_steps=None, but removing the kwarg and letting it default to 4096 balloons the runtime to ~14.4 s.

patrick-kidger added a commit that referenced this issue Oct 9, 2023
The main changes needed to make this happen are in
patrick-kidger/equinox#548
and as such this commit is fairly small -- it declares a dependency on
a new (as yet unreleased) version of Equinox, and removes the
compatibility shim that was there before.

**How things used to be.**

To explain what's going on here a little more carefully: JAX only
recently added support for communicating to a `custom_vjp` which
input arguments were being perturbed, and which output cotangents were
symbolic zeros. Prior to that, a `custom_vjp` basically just had to
differentiate all inexact arrays.

However, there are some nondifferentiable inexact arrays, in the
sense that attempting to differentiate them will raise an error.
Solving SDEs has one such array: the nondifferentiable input to a
VirtualBrownianTree, which is guarded by an `eqxi.nondifferentiable` to
reflect the fact that Brownian motion is nondifferentiable.

So it used to be the case that the `custom_vjp` underlying
`RecursiveCheckpointAdjoint` would differentiate the overall
make-a-step function with respect to all inexact arrays,
including the time variable, hit the `nondifferentiable` guard, and
crash.

One (unsafe) fix would have just been to remove the `nondifferentiable`
guard. In practice I previously took the slower, safer option: silently
switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The
latter is much less efficient, but uses no `custom_vjp`, and thus used
the perturbation and symbolic-zero propagation rules already present
in JAX's AD machinery, and was thus safe to use here.

**How things are now.**

And so, what has now changed: JAX has now added support for tracking
which inputs are perturbed, and which cotangents are symbolic zeros.

The Equinox PR above uses this functionality to determine which parts
of the carry need to be differentiated; no longer is it just "all
inexact arrays".

And thus this Diffrax PR removes the compatibility shim that is no
longer needed.

**Implications.**

SDE solving should now be much faster. In particular this fixes the
speed issue reported in #317.
@patrick-kidger (Owner)

Alright, thank you for the MWE! This ended up being an interesting one to track down. You can find the details in #320 and patrick-kidger/equinox#548.

Basically, solving an SDE with max_steps!=None was hitting an old slow path that existed for compatibility with older versions of JAX.

Using the above PRs and on a V100, I find that the speeds for your benchmark are identical regardless of the value of max_steps. If you get the chance (and are comfortable installing Equinox+Diffrax from these GitHub branches) it'd be great to verify that you see the same results.
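
For reference, a sketch (reusing the setup above) of what this means on the user side: with the fix, an SDE solve no longer silently swaps RecursiveCheckpointAdjoint out for DirectAdjoint, and the default adjoint can also be requested explicitly:

import diffrax as dx

# RecursiveCheckpointAdjoint is diffeqsolve's default adjoint; per the commit
# above, SDE solves previously fell back to the slower DirectAdjoint.
sol = dx.diffeqsolve(sde, solver, t0, t1, dt0=0.001, y0=init_states,
                     saveat=saveat, stepsize_controller=stepsize_controller,
                     adjoint=dx.RecursiveCheckpointAdjoint())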

@Sinestro38 (Author)

Thanks for addressing this! I'll switch over to the new eqx version.

patrick-kidger added a commit that referenced this issue Oct 13, 2023