Question: DirectAdjoint is faster than RecursiveCheckpointAdjoint? #549

Open · CoastEgo opened this issue Jan 1, 2025 · 20 comments

@CoastEgo

CoastEgo commented Jan 1, 2025

Hi,
According to the suggestions in the adjoints docs, the RecursiveCheckpointAdjoint method, given enough checkpoints, should be faster than DirectAdjoint. But in my experience, it turns out that DirectAdjoint is faster. Is there something wrong with my understanding?

Example

An example can be shown using the neural_cde tutorial:

  • With the default setting, adjoint = RecursiveCheckpointAdjoint(), the total run time (excluding the first step/compilation time) is 1.38 s.
  • With adjoint = RecursiveCheckpointAdjoint(checkpoints=4096), the run time is 1.18 s.
  • With adjoint = diffrax.DirectAdjoint(), the run time is 0.77 s.

Environment

  • jax == 0.4.29 jaxlib == 0.4.29
  • diffrax == 0.6.0
  • platform: CPU
@johannahaffner
Contributor

Hi,

What exactly are you benchmarking? The runtime of main()?

Comparing the runtimes of NeuralCDE.__call__ with RecursiveCheckpointAdjoint() and DirectAdjoint(), I get equivalent times in microbenchmarks (1.93 ms for the direct adjoint and 1.94 ms for the recursive one, mean of 1000 loops). Evaluating their gradients takes the same amount of time as well (8.35 vs 8.33 ms, mean of 100 loops).

Both times I used sample_ts, sample_coeffs as input, same as done for the plotting section of the example.

@CoastEgo
Author

CoastEgo commented Jan 1, 2025

Hi,
I simply measure the runtime of the training steps in main(), using the code below:

    total_time = 0
    for step, data_i in zip(
        range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
    ):
        start = time.time()
        bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
        end = time.time()
        time_i = end - start
        if step > 0:  # don't count the compilation time
            total_time += time_i
        print(
            f"Step: {step}, Loss: {bxe}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )
    print(f'total time: {total_time}')

If I benchmark the make_step() function directly with this command:

%timeit jax.block_until_ready(make_step(model, data_i, opt_state))

the result is 64.2 ms for RecursiveCheckpointAdjoint() and 35.1 ms for DirectAdjoint().
It's really confusing.

@johannahaffner
Contributor

Have you called make_step before, to ensure that everything is compiled? Something like this:

run_fn = eqx.filter_jit(fn)
_ = run_fn(inputs)

%timeit run_fn(inputs).block_until_ready()
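Applied to this example, that would be something like the following (a sketch assuming the notebook's make_step, model, data_i and opt_state are in scope; make_step is already wrapped in eqx.filter_jit there):

# warm-up call, so that compilation is excluded from the timing
_ = jax.block_until_ready(make_step(model, data_i, opt_state))

# benchmark the already-compiled function
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))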

@CoastEgo
Author

CoastEgo commented Jan 1, 2025

Yes, I added it below the training loop, like this:

    for step, data_i in zip(
        range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
    ):  
        bxe, acc, model, opt_state = make_step(model, data_i, opt_state)

    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

And the make_step() function inside this notebook is jitted with @eqx.filter_jit.

@johannahaffner
Contributor

I cannot reproduce this; I still get equivalent runtimes if I add a %timeit exactly where you do (134 and 137 ms, mean of 10 loops).

How/where do you specify which adjoint to use?

@patrick-kidger
Owner

patrick-kidger commented Jan 1, 2025

@CoastEgo -- do you have a copy-pastable MWE? (E.g. by starting with the neural CDE example and then minimising it down to something that fits in a GitHub message.) Just to be sure we're all running exactly the same code!

Also, what version of Equinox are you using? The underlying implementation of the while loops (which these adjoint methods use) belongs to Equinox.

It's definitely expected that RecursiveCheckpointAdjoint should be the better choice!

@johannahaffner
Contributor

johannahaffner commented Jan 1, 2025

It's definitely expected that RecursiveCheckpointAdjoint should be the better choice!

Even for such a small example? The solver just takes 50 steps.
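(For reference, the step count can be read from the stats of the Solution that diffeqsolve returns; a self-contained toy sketch, not code from the example:)

import diffrax
import jax.numpy as jnp

# a toy ODE, just to show where the step count lives on the returned Solution
term = diffrax.ODETerm(lambda t, y, args: -y)
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    0.0,
    1.0,
    None,
    jnp.array(1.0),
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
)
print(sol.stats["num_steps"])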

@patrick-kidger
Owner

It's definitely expected that RecursiveCheckpointAdjoint should be the better choice!

Even for such a small example? The solver just takes 50 steps.

Yup! In fact especially so. The cost of DirectAdjoint actually grows primarily with max_steps, not with the number of steps actually taken.
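One way to probe that claim (my own sketch, not code from the example: a toy ODE in which only the max_steps bound varies, while the number of steps actually taken stays fixed):

import diffrax
import jax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)
controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)

def solve(max_steps):
    # only the bound changes here; the number of steps actually taken does not
    return diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        0.0,
        1.0,
        None,
        jnp.array(1.0),
        stepsize_controller=controller,
        adjoint=diffrax.DirectAdjoint(),
        max_steps=max_steps,
    )

# %timeit jax.block_until_ready(solve(4096))  # the default bound
# %timeit jax.block_until_ready(solve(64))    # a tighter bound: expected to be cheaper with DirectAdjoint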

@CoastEgo
Author

CoastEgo commented Jan 1, 2025

Sorry for the confusion! Here is the code:

import math


import diffrax

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax  # https://github.com/deepmind/optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP
    data_size: int
    hidden_size: int

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # Note the use of a tanh final activation function. This is important to
            # stop the model blowing up. (Just like how GRUs and LSTMs constrain the
            # rate of change of their hidden states.)
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y).reshape(self.hidden_size, self.data_size)
class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear
    adjoint_state: int
    def __init__(self, data_size, hidden_size, width_size, depth,adjoint_state, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
        self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
        self.adjoint_state = adjoint_state
    def __call__(self, ts, coeffs):
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = None

        if self.adjoint_state == 0:
            adjoint = diffrax.RecursiveCheckpointAdjoint()
        else:
            adjoint = diffrax.DirectAdjoint()
        # adjoint = diffrax.DirectAdjoint()
        
        y0 = self.initial(control.evaluate(ts[0]))
        saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
            adjoint=adjoint,
        )
        (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
        return prediction

def get_data(dataset_size, add_noise, *, key):
    theta_key, noise_key = jr.split(key, 2)
    length = 100
    theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
    y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
    ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
    matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
    ys = jax.vmap(
        lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
    )(y0, ts)
    ys = jnp.concatenate([ts[:, :, None], ys], axis=-1)  # time is a channel
    ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
    if add_noise:
        ys = ys + jr.normal(noise_key, ys.shape) * 0.1
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = jnp.zeros((dataset_size,))
    labels = labels.at[: dataset_size // 2].set(1.0)
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size


def main(
    dataset_size=256,
    add_noise=False,
    batch_size=32,
    lr=1e-2,
    steps=20,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)

    ts, coeffs, labels, data_size = get_data(
        dataset_size, add_noise, key=train_data_key
    )
    # Training loop like normal.

    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        pred = jax.vmap(model)(ti, coeff_i)
        # Binary cross-entropy
        bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (label_i == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state):
        ti, label_i, *coeff_i = data_i
        (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    data_all = (ts, labels) + coeffs
    data_i = (a[:batch_size] for a in data_all)

    print('recursive checkpoint adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 0, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
    
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

    data_i = (a[:batch_size] for a in data_all)
    print('direct adjoint')
    optim = optax.adam(lr)
    model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 1, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = make_step(model, data_i, opt_state)

    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))
main()

This gives the following result with equinox == 0.11.4:

#recursive checkpoint adjoint
#74.4 ms ± 95.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#direct adjoint
#46.5 ms ± 43.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

When I upgrade all of my packages to jax==0.4.38, jaxlib==0.4.38, diffrax==0.6.2 and equinox==0.11.11, this problem disappears. The result then looks like this: the same performance, but slower 😢

#recursive checkpoint adjoint
#160 ms ± 6.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#direct adjoint
#156 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

@johannahaffner
Contributor

johannahaffner commented Jan 1, 2025

Hi both,

On the newest everything, using your code @CoastEgo, the direct adjoint is a little faster for me (126 ms vs 156 ms). This is when benchmarking make_step, exactly as you do.

However, this does not carry over to simply solving the differential equation, or to taking its derivative. If I run the neural CDE from the example, the integration alone is ~20-25% faster with the checkpointed adjoint. This increases to 33% for the gradients. You can paste this code below yours:

key = jr.key(0)
ts, coeffs, labels, data_size = get_data(dataset_size=256, add_noise=False, key=key)
sample_ts = ts[-1]
sample_coeffs = tuple(c[-1] for c in coeffs)

recursive = NeuralCDE(data_size, 8, 128, 1, adjoint_state=0, key=key)
direct = NeuralCDE(data_size, 8, 128, 1, adjoint_state=1, key=key)

run_recursive = eqx.filter_jit(recursive)
run_direct = eqx.filter_jit(direct)
run_grad_recursive = eqx.filter_jit(eqx.filter_grad(recursive))
run_grad_direct = eqx.filter_jit(eqx.filter_grad(direct))

_ = run_recursive(sample_ts, sample_coeffs)
_ = run_direct(sample_ts, sample_coeffs)
_ = run_grad_recursive(sample_ts, sample_coeffs)
_ = run_grad_direct(sample_ts, sample_coeffs)

print("Timing recursive checkpoint adjoint")
%timeit run_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint")
%timeit run_direct(sample_ts, sample_coeffs).block_until_ready()
print("Timing recursive checkpoint adjoint with gradients")
%timeit run_grad_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint with gradients")
%timeit run_grad_direct(sample_ts, sample_coeffs).block_until_ready()

Edit: updated to switch to a better example and include gradients.

@CoastEgo
Author

CoastEgo commented Jan 2, 2025

Hi @johannahaffner, with the newest packages I get the same result with your code. I guess the different benchmark results between make_step() and NeuralCDE.__call__() come down to two reasons.

  1. The first reason is that inside make_step(), we use a jax.vmap'd version of the model. If I run your code, I get a result like this (recursive is twice as fast!):
Timing recursive checkpoint adjoint
3 ms ± 53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing direct adjoint
3.54 ms ± 222 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing recursive checkpoint adjoint with gradients
9.8 ms ± 447 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing direct adjoint with gradients
18.5 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

But if I run a vmap'd version of your code, the result is different (direct adjoint is slightly faster). Could this be related to the while_loop inside vmap?

batch_size = 32
key = jr.key(0)
ts, coeffs, labels, data_size = get_data(dataset_size=256, add_noise=False, key=key)
sample_ts = ts[:batch_size]

sample_coeffs = tuple(c[:batch_size] for c in coeffs)

recursive = NeuralCDE(data_size, 8, 128, 1, adjoint_state=0, key=key)
direct = NeuralCDE(data_size, 8, 128, 1, adjoint_state=1, key=key)
@eqx.filter_jit
def run_recursive(ts, coeffs):
    result = jax.vmap(recursive)(ts, coeffs)
    return result.mean()
@eqx.filter_jit
def run_direct(ts, coeffs):
    result = jax.vmap(direct)(ts, coeffs)
    return result.mean()

run_grad_recursive = eqx.filter_jit(eqx.filter_grad(run_recursive))
run_grad_direct = eqx.filter_jit(eqx.filter_grad(run_direct))

_ = run_recursive(sample_ts, sample_coeffs)
_ = run_direct(sample_ts, sample_coeffs)
_ = run_grad_recursive(sample_ts, sample_coeffs)
_ = run_grad_direct(sample_ts, sample_coeffs)

print("Timing recursive checkpoint adjoint")
%timeit run_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint")
%timeit run_direct(sample_ts, sample_coeffs).block_until_ready()
print("Timing recursive checkpoint adjoint with gradients")
%timeit jax.block_until_ready(run_grad_recursive(sample_ts, sample_coeffs))
print("Timing direct adjoint with gradients")
%timeit jax.block_until_ready(run_grad_direct(sample_ts, sample_coeffs))
# Timing recursive checkpoint adjoint
# 52.8 ms ± 336 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing direct adjoint
# 52.9 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing recursive checkpoint adjoint with gradients
# 169 ms ± 7.93 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing direct adjoint with gradients
# 147 ms ± 5.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Edit: deleted the second reason; it turned out to be my mistake.

@patrick-kidger
Owner

patrick-kidger commented Jan 2, 2025

I've just run this on my machine and:

  1. With the benchmarking code provided by @CoastEgo, I can reproduce the difference between DirectAdjoint and RecursiveCheckpointAdjoint. That's unfortunate!
  2. When I replace the vmap with a single model invocation -- complete hack, but this:
    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        ti, label_i, coeff_i = jax.tree.map(lambda x: x[0], (ti, label_i, coeff_i))
        pred = model(ti, coeff_i)[None]
        label_i = label_i[None]
        ...
    then the performance is again in RecursiveCheckpointAdjoint's favour. So indeed it definitely seems like a vmap issue.
  3. I think we'll need to simplify this down, e.g. to just an ODE, or even just a direct invocation of the underlying eqxi.while_loop(..., kind="bounded") (=DirectAdjoint) vs eqxi.while_loop(..., kind="checkpointed") (=RecursiveCheckpointAdjoint).

As for the speed difference between Equinox 0.11.4 and Equinox 0.11.11, I'm able to resolve this one pretty easily :) I think all that's happening is that you're also picking up a change in JAX version at the same time, and modern JAX has a known performance bug that is resolved by setting

import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

with this I'm able to get your first faster set of numbers even on the latest versions of JAX and Equinox. I'll go ahead and ask over on the JAX issue tracker what the plan is for that. (EDIT: jax-ml/jax#25711)
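(As in the reproductions further down in this thread, the flag needs to be in the environment before JAX first initializes its backend, which is why those snippets set it at the very top of the script, before importing JAX.)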

Note that your example does have a small bug, by the way: you have a generator data_i = (a[:batch_size] for a in data_all) rather than a tuple data_i = tuple(a[:batch_size] for a in data_all). Strangely, this doesn't seem to impact the results for me.

@johannahaffner
Contributor

johannahaffner commented Jan 2, 2025

Edit: my bad, the example below only has the recursive checkpoint adjoint lose its edge under vmap if not jitted. I just noticed it was missing this after I had posted.

Under JIT, I can no longer reproduce the performance-drop-under-vmap issue when using the introductory CDE example with a quadratic path.

import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp


class QuadraticPath(dfx.AbstractPath):
    @property
    def t0(self):
        return 0

    @property
    def t1(self):
        return 3

    def evaluate(self, t0, t1=None, left=True):
        del left
        if t1 is not None:
            return self.evaluate(t1) - self.evaluate(t0)
        return t0 ** 2

# args unused - can be anything
vector_field = lambda t, y, args: -y

control = QuadraticPath()
term = dfx.ControlTerm(vector_field, control).to_ode()
solver = dfx.Dopri5()

def make_solve(term, solver, adjoint, vmap=False):
    def solve(args):
        return dfx.diffeqsolve(term, solver, 0, 3, 0.05, 1, args, adjoint=adjoint).ys
    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve))
    else:
        return eqx.filter_jit(solve)

direct = dfx.DirectAdjoint()
recursive = dfx.RecursiveCheckpointAdjoint()

direct_solve = make_solve(term, solver, direct)
recursive_solve = make_solve(term, solver, recursive)
vmap_direct_solve = make_solve(term, solver, direct, vmap=True)
vmap_recursive_solve = make_solve(term, solver, recursive, vmap=True)

_ = direct_solve(None)  # warmup
_ = recursive_solve(None)
dummy_args = jnp.zeros((8,))
_ = vmap_direct_solve(dummy_args)
_ = vmap_recursive_solve(dummy_args)

print("Timing direct adjoint: quadratic path")  
%timeit direct_solve(None).block_until_ready()
print("Timing recursive checkpoint adjoint: quadratic path")  
%timeit recursive_solve(None).block_until_ready()
print("Timing direct adjoint: quadratic path (vmap)")
%timeit vmap_direct_solve(dummy_args).block_until_ready()
print("Timing recursive checkpoint adjoint: quadratic path (vmap)")
%timeit vmap_recursive_solve(dummy_args).block_until_ready()

@lockwo
Contributor

lockwo commented Jan 2, 2025

Also, looking at just the while loops, I see the same thing (i.e. checkpoint being faster).

Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import equinox as eqx
import equinox.internal as eqxi
import jax.numpy as jnp
import jax

t0 = 0.0
t1 = 3.0
dt = 0.1
N = int((t1 - t0) / dt)
t = jnp.arange(t0, t1 + dt, dt)
param = -1.0

def cond_fun(carry):
    i, y_array, y_cur = carry
    return i < N

def body_fun(carry):
    i, y_array, y_cur = carry
    y_next = y_cur + dt * (param * y_cur)
    return i+1, y_array.at[i+1].set(y_next), y_next

init_val = (0, jnp.zeros(N+1), 1.0)
_, y_sol, _ = jax.lax.while_loop(cond_fun, body_fun, init_val)

def make_solve(loop_type, vmap=False):
    def solve(args):
        #return jax.lax.while_loop(cond_fun, body_fun, args)[1]
        return eqxi.while_loop(cond_fun, body_fun, args, kind=loop_type, max_steps=4096)[1]
    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve, in_axes=((None, 0, None),)))
    else:
        return eqx.filter_jit(solve)


direct_solve = make_solve("bounded")
recursive_solve = make_solve("checkpointed")
vmap_direct_solve = make_solve("bounded", vmap=True)
vmap_recursive_solve = make_solve("checkpointed", vmap=True)

_ = direct_solve(init_val).block_until_ready()
_ = recursive_solve(init_val).block_until_ready()
init_val_vmap = (0, jnp.zeros((800, N+1)), 1.0)
_ = vmap_direct_solve(init_val_vmap).block_until_ready()
_ = vmap_recursive_solve(init_val_vmap).block_until_ready()

print("Bounded while single")  
%timeit direct_solve(init_val).block_until_ready()
print("Checkpoint while single")  
%timeit recursive_solve(init_val).block_until_ready()
print("Bounded while vmap")
%timeit vmap_direct_solve(init_val_vmap).block_until_ready()
print("Checkpoint while vmap")
%timeit vmap_recursive_solve(init_val_vmap).block_until_ready()
Bounded while single
112 µs ± 2.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Checkpoint while single
112 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Bounded while vmap
151 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Checkpoint while vmap
118 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

One thing that might be causing a difference between @johannahaffner's and my simpler recreations and the original code is the adaptivity. I tested without adaptivity:

Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import math
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax  # https://github.com/deepmind/optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP
    data_size: int
    hidden_size: int

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # Note the use of a tanh final activation function. This is important to
            # stop the model blowing up. (Just like how GRUs and LSTMs constrain the
            # rate of change of their hidden states.)
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y).reshape(self.hidden_size, self.data_size)
class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear
    adjoint_state: int
    def __init__(self, data_size, hidden_size, width_size, depth,adjoint_state, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
        self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
        self.adjoint_state = adjoint_state
    def __call__(self, ts, coeffs):
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[-1] / 100

        if self.adjoint_state == 0:
            adjoint = diffrax.RecursiveCheckpointAdjoint()
        else:
            adjoint = diffrax.DirectAdjoint()
        
        y0 = self.initial(control.evaluate(ts[0]))
        saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            # stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
            adjoint=adjoint,
        )
        (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
        return prediction

def get_data(dataset_size, add_noise, *, key):
    theta_key, noise_key = jr.split(key, 2)
    length = 100
    theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
    y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
    ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
    matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
    ys = jax.vmap(
        lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
    )(y0, ts)
    ys = jnp.concatenate([ts[:, :, None], ys], axis=-1)  # time is a channel
    ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
    if add_noise:
        ys = ys + jr.normal(noise_key, ys.shape) * 0.1
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = jnp.zeros((dataset_size,))
    labels = labels.at[: dataset_size // 2].set(1.0)
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size


def main(
    dataset_size=256,
    add_noise=False,
    batch_size=32,
    lr=1e-2,
    steps=20,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)

    ts, coeffs, labels, data_size = get_data(
        dataset_size, add_noise, key=train_data_key
    )

    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        pred = jax.vmap(model)(ti, coeff_i)
        bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (label_i == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state):
        ti, label_i, *coeff_i = data_i
        (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    data_all = (ts, labels) + coeffs
    data_i = tuple(a[:batch_size] for a in data_all)

    print('recursive checkpoint adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 0, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
    
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

    print('direct adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 1, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))

    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))
main()

and it didn't seem to make much of a difference. Next I thought: maybe the size/code of the vector field? The neural network is much more compute-intensive than our two toy examples. You do see some time difference here (checkpoint is slower than bounded, which it wasn't before, but not by much; I don't know if that's expected).

Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.nn as jnn
import jax.numpy as jnp

print(jax.__version__)
print(eqx.__version__)

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(
    1,
    1,
    128,
    8,
    activation=jnn.softplus,
    final_activation=jnn.tanh,
    key=key,
)

t0 = 0.0
t1 = 30.0
dt = 0.1
N = int((t1 - t0) / dt)
t = jnp.arange(t0, t1 + dt, dt)

def cond_fun(carry):
    i, y_array, y_cur = carry
    return i < N

def body_fun(carry):
    i, y_array, y_cur = carry
    y_next = y_cur + dt * model(jnp.array([y_cur]))[0]
    return i + 1, y_array.at[i+1].set(y_next), y_next

init_val = (jnp.zeros(N+1), jnp.array(1.0))

def make_solve(loop_type, vmap=False):
    def solve(args):
        args = (0, *args)
        max_steps = 4096
        return eqxi.while_loop(cond_fun, body_fun, args, kind=loop_type, max_steps=max_steps)[1]
    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve))
    else:
        return eqx.filter_jit(solve)

direct_solve = make_solve("bounded")
recursive_solve = make_solve("checkpointed")
vmap_direct_solve = make_solve("bounded", vmap=True)
vmap_recursive_solve = make_solve("checkpointed", vmap=True)

_ = direct_solve(init_val).block_until_ready()
_ = recursive_solve(init_val).block_until_ready()
init_val_vmap = (jnp.zeros((32, N+1)), jnp.ones(32))
_ = vmap_direct_solve(init_val_vmap).block_until_ready()
_ = vmap_recursive_solve(init_val_vmap).block_until_ready()

print("Bounded while single")  
%timeit direct_solve(init_val).block_until_ready()
print("Checkpoint while single")  
%timeit recursive_solve(init_val).block_until_ready()
print("Bounded while vmap")
%timeit vmap_direct_solve(init_val_vmap).block_until_ready()
print("Checkpoint while vmap")
%timeit vmap_recursive_solve(init_val_vmap).block_until_ready()
0.4.38
0.11.11
Bounded while single
3.57 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Checkpoint while single
3.6 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Bounded while vmap
73.2 ms ± 259 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Checkpoint while vmap
78.8 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Then lastly, I thought maybe it isn't the loop that's slow but the gradient (since that's also a difference; above we only did inference). So if you remove the gradient calculation, you see both loops are the same speed. (Again, I'm not sure if that's totally expected, but hopefully this provides some more useful data points.)

Code + results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import math
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax  # https://github.com/deepmind/optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP
    data_size: int
    hidden_size: int

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # Note the use of a tanh final activation function. This is important to
            # stop the model blowing up. (Just like how GRUs and LSTMs constrain the
            # rate of change of their hidden states.)
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y).reshape(self.hidden_size, self.data_size)
class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear
    adjoint_state: int
    def __init__(self, data_size, hidden_size, width_size, depth,adjoint_state, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
        self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
        self.adjoint_state = adjoint_state
    def __call__(self, ts, coeffs):
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[-1] / 100

        if self.adjoint_state == 0:
            adjoint = diffrax.RecursiveCheckpointAdjoint()
        else:
            adjoint = diffrax.DirectAdjoint()
        
        y0 = self.initial(control.evaluate(ts[0]))
        saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
            adjoint=adjoint,
        )
        (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
        return prediction

def get_data(dataset_size, add_noise, *, key):
    theta_key, noise_key = jr.split(key, 2)
    length = 100
    theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
    y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
    ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
    matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
    ys = jax.vmap(
        lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
    )(y0, ts)
    ys = jnp.concatenate([ts[:, :, None], ys], axis=-1)  # time is a channel
    ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
    if add_noise:
        ys = ys + jr.normal(noise_key, ys.shape) * 0.1
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = jnp.zeros((dataset_size,))
    labels = labels.at[: dataset_size // 2].set(1.0)
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size


def main(
    dataset_size=256,
    add_noise=False,
    batch_size=32,
    lr=1e-2,
    steps=20,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)

    ts, coeffs, labels, data_size = get_data(
        dataset_size, add_noise, key=train_data_key
    )

    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        pred = jax.vmap(model)(ti, coeff_i)
        bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (label_i == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state):
        ti, label_i, *coeff_i = data_i
        # (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
        bxe, acc = loss(model, ti, label_i, coeff_i)
        return bxe, acc, model, opt_state
        # updates, opt_state = optim.update(grads, opt_state)
        # model = eqx.apply_updates(model, updates)
        # return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    data_all = (ts, labels) + coeffs
    data_i = tuple(a[:batch_size] for a in data_all)

    print('recursive checkpoint adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 0, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
    
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

    print('direct adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 1, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))

    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))
main()
recursive checkpoint adjoint
18.8 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
direct adjoint
18.6 ms ± 40.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@johannahaffner
Contributor

That looks really thorough, @lockwo!

I was also looking into the adaptivity, but only have something very preliminary. I also could not immediately confirm this by switching to constant step sizes.

For a small batch and dataset size (8 and 32), it looks like more steps are taken with the recursive adjoint. Since the number of steps passes a power of 2 at 64, maybe this is a memory allocation thing? I won't have time to look at this in more detail before the weekend though, so all I can do is leave you with this very ugly plot.

[Plot omitted: number of solver steps taken with the recursive vs. direct adjoint.]

@patrick-kidger
Owner

Since you touch on the magic number 64: FWIW DirectAdjoint is implemented in terms of nested jax.lax.scans of length 8, and therefore starts getting different speed/memory (but hopefully not steps) past every power of 8.

But if we're getting different numbers of steps between the adjoints then something has gone very wrong! These should be identical in terms of the mathematics they compute.
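For concreteness (my own arithmetic, just spelling out the above): the 64 that the plot passes is 8**2 and the default max_steps of 4096 is 8**4, i.e. the thresholds at play here are all powers of 8:

# powers of 8 up to the default max_steps=4096; the nested-scan structure of
# DirectAdjoint gains another level each time max_steps crosses one of these
print([8**k for k in range(1, 5)])  # [8, 64, 512, 4096]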

@lockwo
Contributor

lockwo commented Jan 3, 2025

Do you have an MWE for the different steps? I will say I have encountered some interesting edge cases before with vmap of double while loops (like with adaptive stepping), but my minimal example seems to work as expected.

Code
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
import jax

def vector_field(t, y, args):
    prey, predator = y[0], y[1]
    α, β, γ, δ = (0.1, 0.02, 0.4, 0.02)
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = jnp.array([d_prey, d_predator])
    return d_y

term = dfx.ODETerm(vector_field)
solver = dfx.Dopri5()

def make_solve(term, solver, adjoint, vmap=False):
    def solve(init):
        return dfx.diffeqsolve(term, solver, 0, 200, 0.05, init, 
            None, adjoint=adjoint, stepsize_controller=dfx.PIDController(1e-3, 1e-6))
    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve))
    else:
        return eqx.filter_jit(solve)

direct = dfx.DirectAdjoint()
recursive = dfx.RecursiveCheckpointAdjoint()

direct_solve = make_solve(term, solver, direct)
recursive_solve = make_solve(term, solver, recursive)
vmap_direct_solve = make_solve(term, solver, direct, vmap=True)
vmap_recursive_solve = make_solve(term, solver, recursive, vmap=True)


key = jax.random.key(0)
key, subkey = jax.random.split(key)
inits = jax.random.uniform(subkey, (2,))
_ = direct_solve(inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
_ = recursive_solve(inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])

key, subkey = jax.random.split(key)
vmap_inits = jax.random.uniform(subkey, (10, 2))
_ = vmap_direct_solve(vmap_inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
_ = vmap_recursive_solve(vmap_inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])

# print("Timing direct adjoint: quadratic path")  
# %timeit direct_solve(None).block_until_ready()
# print("Timing recursive checkpoint adjoint: quadratic path")  
# %timeit recursive_solve(None).block_until_ready()
# print("Timing direct adjoint: quadratic path (vmap)")
# %timeit vmap_direct_solve(dummy_args).block_until_ready()
# print("Timing recursive checkpoint adjoint: quadratic path (vmap)")
# %timeit vmap_recursive_solve(dummy_args).block_until_ready()

@johannahaffner
Contributor

johannahaffner commented Jan 3, 2025

The plot above is parsed output from jax.debug.print statements placed in NeuralCDE.__call__ 😅

Very preliminary, and not rigorous!
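For reference, the kind of statement that produces those counts looks something like this (my guess at the exact form, placed inside NeuralCDE.__call__ after diffeqsolve returns; solution is the returned Solution object):

jax.debug.print("steps taken: {n}", n=solution.stats["num_steps"])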

@patrick-kidger
Owner

@johannahaffner, does your example here involve a gradient calculation? If so, then that will explain the discrepancy in the number of steps.

Whilst they'll both take the exact same number of steps in the diffeq solve (as I think @lockwo's example demonstrates), they do backpropagate in slightly different ways, by design, with DirectAdjoint being asymptotically less efficient in the number of vector field evaluations (which is why I don't recommend it).

@johannahaffner
Contributor

It does; it runs %timeit on make_step, which does call grad_loss. Makes sense that this would differ!
