Question: DirectAdjoint is faster than RecursiveCheckpointAdjoint? #549
Comments
Hi, what exactly are you benchmarking — the runtime of which function? |
Hi, I am comparing the runtimes of RecursiveCheckpointAdjoint and DirectAdjoint. Both times I used this training loop:
total_time = 0
for step, data_i in zip(
range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
):
start = time.time()
bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
end = time.time()
time_i = end - start
if step > 0: # don't count the compilation time
total_time += time_i
print(
f"Step: {step}, Loss: {bxe}, Accuracy: {acc}, Computation time: "
f"{end - start}"
)
print(f'total time: {total_time}')
If I benchmark make_step directly with
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
the result is 64.2 ms for … |
Have you called it like this — jitted, warmed up once, and with block_until_ready on the result?
run_fn = eqx.filter_jit(fn)
_ = run_fn(inputs)
%timeit run_fn(inputs).block_until_ready() |
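(A small aside on this timing pattern: make_step returns a tuple rather than a single array, so the per-array .block_until_ready() method does not apply to it directly, whereas jax.block_until_ready accepts a whole pytree. A minimal sketch — the fn below is a stand-in, not code from this thread:)

```python
import equinox as eqx
import jax
import jax.numpy as jnp


@eqx.filter_jit
def fn(x):
    # Stand-in for something like make_step: returns a pytree, not one array.
    return x + 1, (x * 2).sum()


x = jnp.arange(4.0)
_ = fn(x)  # warm up: trigger compilation once before timing
out = jax.block_until_ready(fn(x))  # blocks on every array leaf of the tuple
```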
Yes, I added that below the training loop, like this:
for step, data_i in zip(
range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
):
bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
And the … |
I cannot reproduce this; I still get equivalent runtimes if I add a … How/where do you specify which adjoint to use? |
@CoastEgo -- do you have a copy-pastable MWE? (E.g. by starting with the neural CDE example and then minimising it down to something that fits in a GitHub message.) Just to be sure we're all running exactly the same code! Also, what version of Equinox are you using? The underlying implementation of the while loops (which these adjoint methods use) belongs to Equinox. It's definitely expected that RecursiveCheckpointAdjoint should be faster than DirectAdjoint. |
Even for such a small example? The solver just takes 50 steps. |
Yup! In fact especially so. The cost of DirectAdjoint scales with the maximum number of steps (max_steps) rather than with the number of steps actually taken, so a small solve suffers proportionally more. |
Sorry for the confusion! Here is the code:
import math
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax # https://github.com/deepmind/optax
class Func(eqx.Module):
mlp: eqx.nn.MLP
data_size: int
hidden_size: int
def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
self.data_size = data_size
self.hidden_size = hidden_size
self.mlp = eqx.nn.MLP(
in_size=hidden_size,
out_size=hidden_size * data_size,
width_size=width_size,
depth=depth,
activation=jnn.softplus,
# Note the use of a tanh final activation function. This is important to
# stop the model blowing up. (Just like how GRUs and LSTMs constrain the
# rate of change of their hidden states.)
final_activation=jnn.tanh,
key=key,
)
def __call__(self, t, y, args):
return self.mlp(y).reshape(self.hidden_size, self.data_size)
class NeuralCDE(eqx.Module):
initial: eqx.nn.MLP
func: Func
linear: eqx.nn.Linear
adjoint_state: int
def __init__(self, data_size, hidden_size, width_size, depth,adjoint_state, *, key, **kwargs):
super().__init__(**kwargs)
ikey, fkey, lkey = jr.split(key, 3)
self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
self.adjoint_state = adjoint_state
def __call__(self, ts, coeffs):
control = diffrax.CubicInterpolation(ts, coeffs)
term = diffrax.ControlTerm(self.func, control).to_ode()
solver = diffrax.Tsit5()
dt0 = None
if self.adjoint_state == 0:
adjoint = diffrax.RecursiveCheckpointAdjoint()
else:
adjoint = diffrax.DirectAdjoint()
# adjoint = diffrax.DirectAdjoint()
y0 = self.initial(control.evaluate(ts[0]))
saveat = diffrax.SaveAt(t1=True)
solution = diffrax.diffeqsolve(
term,
solver,
ts[0],
ts[-1],
dt0,
y0,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
saveat=saveat,
adjoint=adjoint,
)
(prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
return prediction
def get_data(dataset_size, add_noise, *, key):
theta_key, noise_key = jr.split(key, 2)
length = 100
theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
ys = jax.vmap(
lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
)(y0, ts)
ys = jnp.concatenate([ts[:, :, None], ys], axis=-1) # time is a channel
ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
if add_noise:
ys = ys + jr.normal(noise_key, ys.shape) * 0.1
coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
labels = jnp.zeros((dataset_size,))
labels = labels.at[: dataset_size // 2].set(1.0)
_, _, data_size = ys.shape
return ts, coeffs, labels, data_size
def main(
dataset_size=256,
add_noise=False,
batch_size=32,
lr=1e-2,
steps=20,
hidden_size=8,
width_size=128,
depth=1,
seed=5678,
):
key = jr.PRNGKey(seed)
train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)
ts, coeffs, labels, data_size = get_data(
dataset_size, add_noise, key=train_data_key
)
# Training loop like normal.
@eqx.filter_jit
def loss(model, ti, label_i, coeff_i):
pred = jax.vmap(model)(ti, coeff_i)
# Binary cross-entropy
bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
bxe = -jnp.mean(bxe)
acc = jnp.mean((pred > 0.5) == (label_i == 1))
return bxe, acc
grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)
@eqx.filter_jit
def make_step(model, data_i, opt_state):
ti, label_i, *coeff_i = data_i
(bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return bxe, acc, model, opt_state
optim = optax.adam(lr)
data_all = (ts, labels) + coeffs
data_i = (a[:batch_size] for a in data_all)
print('recursive checkpoint adjoint')
model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 0, key=model_key)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
data_i = (a[:batch_size] for a in data_all)
print('direct adjoint')
optim = optax.adam(lr)
model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 1, key=model_key)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
main()
This gives the following results with equinox == 0.11.4:
#recursive checkpoint adjoint
#74.4 ms ± 95.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#direct adjoint
#46.5 ms ± 43.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
When I upgrade all of my packages:
#recursive checkpoint adjoint
#160 ms ± 6.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#direct adjoint
#156 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) |
Hi both, on newest everything, using your code @CoastEgo, the direct adjoint is a little faster for me (126 ms vs 156 ms). This is when benchmarking make_step. However, this does not carry over to simply solving the differential equation, or to taking its derivative. If I run the neural CDE from the example, the integration alone is ~20-25% faster with the checkpointed adjoint. This increases to 33% for the gradients. You can paste this code below yours:
key = jr.key(0)
ts, coeffs, labels, data_size = get_data(dataset_size=256, add_noise=False, key=key)
sample_ts = ts[-1]
sample_coeffs = tuple(c[-1] for c in coeffs)
recursive = NeuralCDE(data_size, 8, 128, 1, adjoint_state=0, key=key)
direct = NeuralCDE(data_size, 8, 128, 1, adjoint_state=1, key=key)
run_recursive = eqx.filter_jit(recursive)
run_direct = eqx.filter_jit(direct)
run_grad_recursive = eqx.filter_jit(eqx.filter_grad(recursive))
run_grad_direct = eqx.filter_jit(eqx.filter_grad(direct))
_ = run_recursive(sample_ts, sample_coeffs)
_ = run_direct(sample_ts, sample_coeffs)
_ = run_grad_recursive(sample_ts, sample_coeffs)
_ = run_grad_direct(sample_ts, sample_coeffs)
print("Timing recursive checkpoint adjoint")
%timeit run_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint")
%timeit run_direct(sample_ts, sample_coeffs).block_until_ready()
print("Timing recursive checkpoint adjoint with gradients")
%timeit run_grad_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint with gradients")
%timeit run_grad_direct(sample_ts, sample_coeffs).block_until_ready()
Edit: update to switch to a better example, and include gradients. |
Hi @johannahaffner, with the newest packages I get the same result with your code. I guess the different benchmark results between make_step and your code come from the batch size. With your code I get:
Timing recursive checkpoint adjoint
3 ms ± 53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing direct adjoint
3.54 ms ± 222 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing recursive checkpoint adjoint with gradients
9.8 ms ± 447 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing direct adjoint with gradients
18.5 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
But if I run a batched (vmapped) version:
batch_size = 32
key = jr.key(0)
ts, coeffs, labels, data_size = get_data(dataset_size=256, add_noise=False, key=key)
sample_ts = ts[:batch_size]
sample_coeffs = tuple(c[:batch_size] for c in coeffs)
recursive = NeuralCDE(data_size, 8, 128, 1, adjoint_state=0, key=key)
direct = NeuralCDE(data_size, 8, 128, 1, adjoint_state=1, key=key)
@eqx.filter_jit
def run_recursive(ts, coeffs):
result = jax.vmap(recursive)(ts, coeffs)
return result.mean()
@eqx.filter_jit
def run_direct(ts, coeffs):
result = jax.vmap(direct)(ts, coeffs)
return result.mean()
run_grad_recursive = eqx.filter_jit(eqx.filter_grad(run_recursive))
run_grad_direct = eqx.filter_jit(eqx.filter_grad(run_direct))
_ = run_recursive(sample_ts, sample_coeffs)
_ = run_direct(sample_ts, sample_coeffs)
_ = run_grad_recursive(sample_ts, sample_coeffs)
_ = run_grad_direct(sample_ts, sample_coeffs)
print("Timing recursive checkpoint adjoint")
%timeit run_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint")
%timeit run_direct(sample_ts, sample_coeffs).block_until_ready()
print("Timing recursive checkpoint adjoint with gradients")
%timeit jax.block_until_ready(run_grad_recursive(sample_ts, sample_coeffs))
print("Timing direct adjoint with gradients")
%timeit jax.block_until_ready(run_grad_direct(sample_ts, sample_coeffs))
# Timing recursive checkpoint adjoint
# 52.8 ms ± 336 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing direct adjoint
# 52.9 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing recursive checkpoint adjoint with gradients
# 169 ms ± 7.93 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing direct adjoint with gradients
# 147 ms ± 5.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Edit: delete the second reason. It turns out to be my mistake. |
I've just run this on my machine, and: at least for the speed difference between Equinox 0.11.4 and Equinox 0.11.11, I'm able to resolve this one pretty easily :) I think all that's happening is that you're also picking up a change in JAX version at the same time, and modern JAX has a known performance bug that is resolved by setting
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
With this I'm able to get your first, faster set of numbers even on the latest versions of JAX and Equinox. I'll go ahead and ask over on the JAX issue tracker what the plan is for that. (EDIT: jax-ml/jax#25711) Note that your example does have a small bug by the way: you have a generator expression where you presumably want a tuple, so data_i is exhausted after its first use. |
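To spell out the generator point with a minimal sketch (the arrays below are dummy stand-ins, not the real data): a generator expression is consumed after one pass, so only the first use would see the batch; materialising it as a tuple, as the later snippets in this thread do, avoids that.

```python
import jax.numpy as jnp

# Dummy stand-ins for (ts, labels) + coeffs; shapes are illustrative only.
data_all = (jnp.zeros((256, 100)), jnp.zeros((256,)), jnp.zeros((256, 99, 3)))
batch_size = 32

gen = (a[:batch_size] for a in data_all)  # generator expression
print(len(tuple(gen)))  # 3 -- the first pass consumes it
print(len(tuple(gen)))  # 0 -- exhausted on any later use

data_i = tuple(a[:batch_size] for a in data_all)  # materialised once, safe to reuse
print(len(data_i), len(data_i))  # 3 3
```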
Edit: my bad, the example below only has the recursive checkpoint adjoint lose its edge under vmap if it is not jitted. Just noticed it was missing this after I had posted. Under JIT, I can no longer reproduce the performance-drop-under-vmap issue when using the introductory CDE example with a quadratic path.
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
class QuadraticPath(dfx.AbstractPath):
@property
def t0(self):
return 0
@property
def t1(self):
return 3
def evaluate(self, t0, t1=None, left=True):
del left
if t1 is not None:
return self.evaluate(t1) - self.evaluate(t0)
return t0 ** 2
# args unused - can be anything
vector_field = lambda t, y, args: -y
control = QuadraticPath()
term = dfx.ControlTerm(vector_field, control).to_ode()
solver = dfx.Dopri5()
def make_solve(term, solver, adjoint, vmap=False):
def solve(args):
return dfx.diffeqsolve(term, solver, 0, 3, 0.05, 1, args, adjoint=adjoint).ys
if vmap:
return eqx.filter_jit(eqx.filter_vmap(solve))
else:
return eqx.filter_jit(solve)
direct = dfx.DirectAdjoint()
recursive = dfx.RecursiveCheckpointAdjoint()
direct_solve = make_solve(term, solver, direct)
recursive_solve = make_solve(term, solver, recursive)
vmap_direct_solve = make_solve(term, solver, direct, vmap=True)
vmap_recursive_solve = make_solve(term, solver, recursive, vmap=True)
_ = direct_solve(None) # warmup
_ = recursive_solve(None)
dummy_args = jnp.zeros((8,))
_ = vmap_direct_solve(dummy_args)
_ = vmap_recursive_solve(dummy_args)
print("Timing direct adjoint: quadratic path")
%timeit direct_solve(None).block_until_ready()
print("Timing recursive checkpoint adjoint: quadratic path")
%timeit recursive_solve(None).block_until_ready()
print("Timing direct adjoint: quadratic path (vmap)")
%timeit vmap_direct_solve(dummy_args).block_until_ready()
print("Timing recursive checkpoint adjoint: quadratic path (vmap)")
%timeit vmap_recursive_solve(dummy_args).block_until_ready() |
Also, looking at just the while loops, I see the same thing (i.e. the checkpointed loop being faster).
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
import equinox as eqx
import equinox.internal as eqxi
import jax.numpy as jnp
import jax
t0 = 0.0
t1 = 3.0
dt = 0.1
N = int((t1 - t0) / dt)
t = jnp.arange(t0, t1 + dt, dt)
param = -1.0
def cond_fun(carry):
i, y_array, y_cur = carry
return i < N
def body_fun(carry):
i, y_array, y_cur = carry
y_next = y_cur + dt * (param * y_cur)
return i+1, y_array.at[i+1].set(y_next), y_next
def cond_fun(carry):
i, y_array, y_cur = carry
return i < N
def body_fun(carry):
i, y_array, y_cur = carry
y_next = y_cur + dt * -y_cur
return i+1, y_array.at[i+1].set(y_next), y_next
init_val = (0, jnp.zeros(N+1), 1.0)
_, y_sol, _ = jax.lax.while_loop(cond_fun, body_fun, init_val)
def make_solve(loop_type, vmap=False):
def solve(args):
#return jax.lax.while_loop(cond_fun, body_fun, args)[1]
return eqxi.while_loop(cond_fun, body_fun, args, kind=loop_type, max_steps=4096)[1]
if vmap:
return eqx.filter_jit(eqx.filter_vmap(solve, in_axes=((None, 0, None),)))
else:
return eqx.filter_jit(solve)
direct_solve = make_solve("bounded")
recursive_solve = make_solve("checkpointed")
vmap_direct_solve = make_solve("bounded", vmap=True)
vmap_recursive_solve = make_solve("checkpointed", vmap=True)
_ = direct_solve(init_val).block_until_ready()
_ = recursive_solve(init_val).block_until_ready()
init_val_vmap = (0, jnp.zeros((800, N+1)), 1.0)
_ = vmap_direct_solve(init_val_vmap).block_until_ready()
_ = vmap_recursive_solve(init_val_vmap).block_until_ready()
print("Bounded while single")
%timeit direct_solve(init_val).block_until_ready()
print("Checkpoint while single")
%timeit recursive_solve(init_val).block_until_ready()
print("Bounded while vmap")
%timeit vmap_direct_solve(init_val_vmap).block_until_ready()
print("Checkpoint while vmap")
%timeit vmap_recursive_solve(init_val_vmap).block_until_ready()
One thing that might be causing a difference between @johannahaffner's and my simpler recreations and the original code is the adaptivity. I tested without adaptivity:
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
import math
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax # https://github.com/deepmind/optax
class Func(eqx.Module):
mlp: eqx.nn.MLP
data_size: int
hidden_size: int
def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
self.data_size = data_size
self.hidden_size = hidden_size
self.mlp = eqx.nn.MLP(
in_size=hidden_size,
out_size=hidden_size * data_size,
width_size=width_size,
depth=depth,
activation=jnn.softplus,
# Note the use of a tanh final activation function. This is important to
# stop the model blowing up. (Just like how GRUs and LSTMs constrain the
# rate of change of their hidden states.)
final_activation=jnn.tanh,
key=key,
)
def __call__(self, t, y, args):
return self.mlp(y).reshape(self.hidden_size, self.data_size)
class NeuralCDE(eqx.Module):
initial: eqx.nn.MLP
func: Func
linear: eqx.nn.Linear
adjoint_state: int
def __init__(self, data_size, hidden_size, width_size, depth,adjoint_state, *, key, **kwargs):
super().__init__(**kwargs)
ikey, fkey, lkey = jr.split(key, 3)
self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
self.adjoint_state = adjoint_state
def __call__(self, ts, coeffs):
control = diffrax.CubicInterpolation(ts, coeffs)
term = diffrax.ControlTerm(self.func, control).to_ode()
solver = diffrax.Tsit5()
dt0 = ts[-1] / 100
if self.adjoint_state == 0:
adjoint = diffrax.RecursiveCheckpointAdjoint()
else:
adjoint = diffrax.DirectAdjoint()
y0 = self.initial(control.evaluate(ts[0]))
saveat = diffrax.SaveAt(t1=True)
solution = diffrax.diffeqsolve(
term,
solver,
ts[0],
ts[-1],
dt0,
y0,
# stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
saveat=saveat,
adjoint=adjoint,
)
(prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
return prediction
def get_data(dataset_size, add_noise, *, key):
theta_key, noise_key = jr.split(key, 2)
length = 100
theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
ys = jax.vmap(
lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
)(y0, ts)
ys = jnp.concatenate([ts[:, :, None], ys], axis=-1) # time is a channel
ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
if add_noise:
ys = ys + jr.normal(noise_key, ys.shape) * 0.1
coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
labels = jnp.zeros((dataset_size,))
labels = labels.at[: dataset_size // 2].set(1.0)
_, _, data_size = ys.shape
return ts, coeffs, labels, data_size
def main(
dataset_size=256,
add_noise=False,
batch_size=32,
lr=1e-2,
steps=20,
hidden_size=8,
width_size=128,
depth=1,
seed=5678,
):
key = jr.PRNGKey(seed)
train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)
ts, coeffs, labels, data_size = get_data(
dataset_size, add_noise, key=train_data_key
)
@eqx.filter_jit
def loss(model, ti, label_i, coeff_i):
pred = jax.vmap(model)(ti, coeff_i)
bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
bxe = -jnp.mean(bxe)
acc = jnp.mean((pred > 0.5) == (label_i == 1))
return bxe, acc
grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)
@eqx.filter_jit
def make_step(model, data_i, opt_state):
ti, label_i, *coeff_i = data_i
(bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return bxe, acc, model, opt_state
optim = optax.adam(lr)
data_all = (ts, labels) + coeffs
data_i = tuple(a[:batch_size] for a in data_all)
print('recursive checkpoint adjoint')
model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 0, key=model_key)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
print('direct adjoint')
model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 1, key=model_key)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
main()
and it didn't seem to make much of a difference. Next I thought: maybe the size/code of the vector field? The neural network is much more compute-intensive than our two toy examples. You do see some time difference here (checkpointed is slower than bounded, which it wasn't before, but not by much; idk if that's expected).
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.nn as jnn
import jax.numpy as jnp
print(jax.__version__)
print(eqx.__version__)
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(
1,
1,
128,
8,
activation=jnn.softplus,
final_activation=jnn.tanh,
key=key,
)
t0 = 0.0
t1 = 30.0
dt = 0.1
N = int((t1 - t0) / dt)
t = jnp.arange(t0, t1 + dt, dt)
def cond_fun(carry):
i, y_array, y_cur = carry
return i < N
def body_fun(carry):
i, y_array, y_cur = carry
y_next = y_cur + dt * model(jnp.array([y_cur]))[0]
return i + 1, y_array.at[i+1].set(y_next), y_next
init_val = (jnp.zeros(N+1), jnp.array(1.0))
def make_solve(loop_type, vmap=False):
def solve(args):
args = (0, *args)
max_steps = 4096
return eqxi.while_loop(cond_fun, body_fun, args, kind=loop_type, max_steps=max_steps)[1]
if vmap:
return eqx.filter_jit(eqx.filter_vmap(solve))
else:
return eqx.filter_jit(solve)
direct_solve = make_solve("bounded")
recursive_solve = make_solve("checkpointed")
vmap_direct_solve = make_solve("bounded", vmap=True)
vmap_recursive_solve = make_solve("checkpointed", vmap=True)
_ = direct_solve(init_val).block_until_ready()
_ = recursive_solve(init_val).block_until_ready()
init_val_vmap = (jnp.zeros((32, N+1)), jnp.ones(32))
_ = vmap_direct_solve(init_val_vmap).block_until_ready()
_ = vmap_recursive_solve(init_val_vmap).block_until_ready()
print("Bounded while single")
%timeit direct_solve(init_val).block_until_ready()
print("Checkpoint while single")
%timeit recursive_solve(init_val).block_until_ready()
print("Bounded while vmap")
%timeit vmap_direct_solve(init_val_vmap).block_until_ready()
print("Checkpoint while vmap")
%timeit vmap_recursive_solve(init_val_vmap).block_until_ready()
Then lastly, I thought maybe it isn't the loop that's slow but the gradient (since that's also a difference: before, we just did inference). If you remove the gradient calculation, you see both loops are the same speed. (Again, not sure if that's totally expected, but hopefully this provides some more useful data points.)
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
import math
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax # https://github.com/deepmind/optax
class Func(eqx.Module):
mlp: eqx.nn.MLP
data_size: int
hidden_size: int
def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
self.data_size = data_size
self.hidden_size = hidden_size
self.mlp = eqx.nn.MLP(
in_size=hidden_size,
out_size=hidden_size * data_size,
width_size=width_size,
depth=depth,
activation=jnn.softplus,
# Note the use of a tanh final activation function. This is important to
# stop the model blowing up. (Just like how GRUs and LSTMs constrain the
# rate of change of their hidden states.)
final_activation=jnn.tanh,
key=key,
)
def __call__(self, t, y, args):
return self.mlp(y).reshape(self.hidden_size, self.data_size)
class NeuralCDE(eqx.Module):
initial: eqx.nn.MLP
func: Func
linear: eqx.nn.Linear
adjoint_state: int
def __init__(self, data_size, hidden_size, width_size, depth,adjoint_state, *, key, **kwargs):
super().__init__(**kwargs)
ikey, fkey, lkey = jr.split(key, 3)
self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
self.adjoint_state = adjoint_state
def __call__(self, ts, coeffs):
control = diffrax.CubicInterpolation(ts, coeffs)
term = diffrax.ControlTerm(self.func, control).to_ode()
solver = diffrax.Tsit5()
dt0 = ts[-1] / 100
if self.adjoint_state == 0:
adjoint = diffrax.RecursiveCheckpointAdjoint()
else:
adjoint = diffrax.DirectAdjoint()
y0 = self.initial(control.evaluate(ts[0]))
saveat = diffrax.SaveAt(t1=True)
solution = diffrax.diffeqsolve(
term,
solver,
ts[0],
ts[-1],
dt0,
y0,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
saveat=saveat,
adjoint=adjoint,
)
(prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
return prediction
def get_data(dataset_size, add_noise, *, key):
theta_key, noise_key = jr.split(key, 2)
length = 100
theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
ys = jax.vmap(
lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
)(y0, ts)
ys = jnp.concatenate([ts[:, :, None], ys], axis=-1) # time is a channel
ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
if add_noise:
ys = ys + jr.normal(noise_key, ys.shape) * 0.1
coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
labels = jnp.zeros((dataset_size,))
labels = labels.at[: dataset_size // 2].set(1.0)
_, _, data_size = ys.shape
return ts, coeffs, labels, data_size
def main(
dataset_size=256,
add_noise=False,
batch_size=32,
lr=1e-2,
steps=20,
hidden_size=8,
width_size=128,
depth=1,
seed=5678,
):
key = jr.PRNGKey(seed)
train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)
ts, coeffs, labels, data_size = get_data(
dataset_size, add_noise, key=train_data_key
)
@eqx.filter_jit
def loss(model, ti, label_i, coeff_i):
pred = jax.vmap(model)(ti, coeff_i)
bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
bxe = -jnp.mean(bxe)
acc = jnp.mean((pred > 0.5) == (label_i == 1))
return bxe, acc
grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)
@eqx.filter_jit
def make_step(model, data_i, opt_state):
ti, label_i, *coeff_i = data_i
# (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
bxe, acc = loss(model, ti, label_i, coeff_i)
return bxe, acc, model, opt_state
# updates, opt_state = optim.update(grads, opt_state)
# model = eqx.apply_updates(model, updates)
# return bxe, acc, model, opt_state
optim = optax.adam(lr)
data_all = (ts, labels) + coeffs
data_i = tuple(a[:batch_size] for a in data_all)
print('recursive checkpoint adjoint')
model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 0, key=model_key)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
print('direct adjoint')
model = NeuralCDE(data_size, hidden_size, width_size, depth,adjoint_state = 1, key=model_key)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
main()
|
That looks really thorough, @lockwo! I was also looking into the adaptivity, but only got something very preliminary, and I could not immediately confirm it by switching to constant step sizes. For a small batch and dataset size (8 and 32), it looks like more steps are taken with the recursive adjoint. Since the number of steps passes a power of 2 at 64, maybe this is a memory allocation thing? I won't have time to look at this in more detail before the weekend though, so all I can do is leave you with this very ugly plot. |
Since you touch on the magic number 64: FWIW … But if we're getting different numbers of steps between the adjoints then something has gone very wrong! These should be identical in terms of the mathematics they compute. |
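For what it's worth, one way to take the default checkpoint count out of the comparison is to pin checkpoints to the solve's max_steps, as in the sketch below (4096 is assumed here to match diffeqsolve's default max_steps; the checkpoints argument is the same one used in the issue description):

```python
import diffrax

# Sketch only: with checkpoints equal to max_steps, RecursiveCheckpointAdjoint
# stores every step and never has to recompute, so any remaining gap versus
# DirectAdjoint is not down to the default checkpointing heuristic.
max_steps = 4096  # assumed to match the max_steps used by diffeqsolve
recursive_adjoint = diffrax.RecursiveCheckpointAdjoint(checkpoints=max_steps)
direct_adjoint = diffrax.DirectAdjoint()
```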
Do you have a minimal example for the different step counts? I will say I have encountered some interesting edge cases before with vmap of nested while loops (as with adaptive stepping), but my minimal example below seems to work as expected.
Code
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
import jax
def vector_field(t, y, args):
prey, predator = y[0], y[1]
α, β, γ, δ = (0.1, 0.02, 0.4, 0.02)
d_prey = α * prey - β * prey * predator
d_predator = -γ * predator + δ * prey * predator
d_y = jnp.array([d_prey, d_predator])
return d_y
term = dfx.ODETerm(vector_field)
solver = dfx.Dopri5()
def make_solve(term, solver, adjoint, vmap=False):
def solve(init):
return dfx.diffeqsolve(term, solver, 0, 200, 0.05, init,
None, adjoint=adjoint, stepsize_controller=dfx.PIDController(1e-3, 1e-6))
if vmap:
return eqx.filter_jit(eqx.filter_vmap(solve))
else:
return eqx.filter_jit(solve)
direct = dfx.DirectAdjoint()
recursive = dfx.RecursiveCheckpointAdjoint()
direct_solve = make_solve(term, solver, direct)
recursive_solve = make_solve(term, solver, recursive)
vmap_direct_solve = make_solve(term, solver, direct, vmap=True)
vmap_recursive_solve = make_solve(term, solver, recursive, vmap=True)
key = jax.random.key(0)
key, subkey = jax.random.split(key)
inits = jax.random.uniform(subkey, (2,))
_ = direct_solve(inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
_ = recursive_solve(inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
key, subkey = jax.random.split(key)
vmap_inits = jax.random.uniform(subkey, (10, 2))
_ = vmap_direct_solve(vmap_inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
_ = vmap_recursive_solve(vmap_inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
# print("Timing direct adjoint: quadratic path")
# %timeit direct_solve(None).block_until_ready()
# print("Timing recursive checkpoint adjoint: quadratic path")
# %timeit recursive_solve(None).block_until_ready()
# print("Timing direct adjoint: quadratic path (vmap)")
# %timeit vmap_direct_solve(dummy_args).block_until_ready()
# print("Timing recursive checkpoint adjoint: quadratic path (vmap)")
# %timeit vmap_recursive_solve(dummy_args).block_until_ready() |
The plot above is parsed output from … Very preliminary, and not rigorous! |
@johannahaffner, does your example here involve a gradient calculation? If so then that will explain the discrepancy in the number of steps. Whilst they'll both take the exact same number of steps in the diffeq solve (as I think @lockwo's example demonstrates), they do backpropagate in slightly different ways, by design, with RecursiveCheckpointAdjoint recomputing some forward steps during the backward pass. |
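As a loose illustration of that recompute-versus-store distinction (an analogy using jax.checkpoint over a plain lax.scan, not Diffrax's actual adjoint code): the gradients agree either way, and only where the forward work gets re-done differs.

```python
import jax
import jax.numpy as jnp


def step(y, _):
    # Cheap stand-in for one solver step.
    return jnp.tanh(y) * 1.01, None


def solve_store(y0):
    # Plain scan: every intermediate is kept for the backward pass.
    y, _ = jax.lax.scan(step, y0, None, length=1000)
    return y.sum()


def solve_recompute(y0):
    # Checkpointed step: intermediates are recomputed while backpropagating.
    y, _ = jax.lax.scan(jax.checkpoint(step), y0, None, length=1000)
    return y.sum()


y0 = jnp.ones(8)
g_store = jax.grad(solve_store)(y0)
g_recompute = jax.grad(solve_recompute)(y0)
print(jnp.allclose(g_store, g_recompute))  # True: same mathematics, different bookkeeping
```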
It does; it runs … |
Hi,
According to the suggestions in the adjoints docs, the RecursiveCheckpointAdjoint method, given enough checkpoints, should be faster than DirectAdjoint. But in my practice, it turns out that DirectAdjoint is faster. Is there something wrong in my understanding?
Example
An example can be shown using the neural_cde tutorial:
with adjoint = RecursiveCheckpointAdjoint(), the total run time (without the first step / compilation time) is 1.38 s
with adjoint = RecursiveCheckpointAdjoint(checkpoints=4096), the run time is 1.18 s
Environment
jax == 0.4.29
jaxlib == 0.4.29
diffrax == 0.6.0