From 8138c92d097b726407258a7d5c1a7ed84bc2de8f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 7 Oct 2023 10:47:07 -0700 Subject: [PATCH 1/5] Fixed beartype complaining about nested loops --- equinox/internal/_loop/common.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/equinox/internal/_loop/common.py b/equinox/internal/_loop/common.py index 5a50d9c8..f732d52b 100644 --- a/equinox/internal/_loop/common.py +++ b/equinox/internal/_loop/common.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any import jax import jax.core @@ -8,7 +8,7 @@ import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu -from jaxtyping import Array, Bool, Shaped +from jaxtyping import Array, Bool from ..._filters import combine, is_array, partition from ..._module import field, Module @@ -248,7 +248,8 @@ def _maybe_set(pred, xs, x, i, *, kwargs, makes_false_steps): class _Buffer(Module): - _array: Union[Shaped[Array, "..."], "_Buffer"] + # annotation removed because beartype can't handle the forward reference. + _array: Any # Union[Shaped[Array, "..."], _Buffer] _pred: Bool[Array, ""] _tag: object = field(static=True) _makes_false_steps: bool = field(static=True) From 7c608429675d12ddba1cad92554100e4307a2086 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 8 Oct 2023 19:36:24 -0700 Subject: [PATCH 2/5] CWL now compatible with partially nondifferentiable body functions. This is fairly finickity! So, the checkpointed while loop (CWL) needs to figure out which cotangents to propagate backward. There are two criteria for this: (1) only some inputs actually need cotangents at all. Any parts of the carry that are unrelated to those inputs can be skipped. (2) some cotangents might be resolvable as symbolic zeros, and can thus be skipped. Previously, we handled (2) but did not handle (1). In practice this was mostly fine. We just pretended all (inexact) inputs needed cotangents, and let anything superfluous get DCE'd at the end. However, this has one niche problem, which is that it is incompatible with `eqxi.nondifferentiable` -- which is used to mark an inexact input as nondifferentiable, and raise a trace-time error if we attempt to differentiate it. With this change, we're now a bit more careful: we run an additional fixed-point iteration to check which inputs need cotangents, and then skip the rest. Note that this extra fixed-point iteration is actually done using a JVP. This is designed to mirror what JAX does internally: perturbations are tracked using a fixed-point iteration inside the JVP rule for `lax.{while_loop,scan}`. I considered some other alternatives (using VJPs or `pe.dce_jaxpr`) but I think those either fail in corner cases or run into catch-22s. 
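To make the first fixed point concrete, here is a minimal, self-contained sketch of the idea (illustrative only -- it uses a hypothetical explicit dependency table `deps`, whereas the real implementation discovers dependencies by pushing symbolic-zero tangents through `jax.linearize`):

```python
# Propagate "is perturbed" flags through one step of the loop body until no
# new flags appear. `deps[k]` lists the carry leaves that leaf `k` reads from.
def perturbation_fixed_point(perturbed, deps):
    while True:
        new = {
            k: perturbed[k] or any(perturbed[d] for d in deps[k])
            for k in perturbed
        }
        if new == perturbed:  # converged: the perturbation set is stable
            return perturbed
        perturbed = new

# Carry (x, y, z): y reads x, z reads y. Perturbing only x eventually perturbs
# everything, but reaching z takes a second pass -- hence the fixed point.
deps = {"x": [], "y": ["x"], "z": ["y"]}
assert perturbation_fixed_point({"x": True, "y": False, "z": False}, deps) == {
    "x": True,
    "y": True,
    "z": True,
}
```

Any leaf still unperturbed at the fixed point never needs a cotangent, so an `eqxi.nondifferentiable` marker on it is never differentiated through.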
---
 equinox/internal/_loop/checkpointed.py | 186 +++++++++++++++++++++----
 tests/test_while_loop.py               | 124 +++++++++++++++++
 2 files changed, 282 insertions(+), 28 deletions(-)

diff --git a/equinox/internal/_loop/checkpointed.py b/equinox/internal/_loop/checkpointed.py
index f734e69a..51e3c4ef 100644
--- a/equinox/internal/_loop/checkpointed.py
+++ b/equinox/internal/_loop/checkpointed.py
@@ -660,6 +660,10 @@ def _is_symbolic_zero_or_none(x):
     return isinstance(x, jax.custom_derivatives.SymbolicZero) or x is None
 
 
+def _not_symbolic_zero(x):
+    return not isinstance(x, jax.custom_derivatives.SymbolicZero)
+
+
 def _materialise_symbolic_zero(is_zero, ct_x, x):
     if is_zero:
         assert ct_x is None
@@ -753,7 +757,7 @@ def _checkpointed_while_loop_bwd(
     loading values from checkpoints and using treeverse to toggle between forward
     and backward steps.
     """
-    _, perturb_body_fun = perturbed
+    perturb_init_val, perturb_body_fun = perturbed
     _, body_fun = vjp_arg
     grad_final_body_fun = jtu.tree_map(
         lambda x, p: jnp.zeros_like(x) if p else None, body_fun, perturb_body_fun
@@ -915,46 +919,166 @@ def _to_vjp(_diff_body, _diff_val):
         init_step_val, init_step_grad_val, init_index, checkpoints
     )
 
-    # Fixed-point iteration: promote symbolic zeros to materialised zeros if necessary.
-    # Do it inside the dynamic context of a `jax.eval_shape` so that it's as cheap as
+    #
+    # Okay, this next bit is a bit complicated.
+    # We need to figure out exactly which parts of the carry we're going to compute
+    # cotangents for.
+    # There are three interacting pieces:
+    # (1) Only some inputs even need their cotangents computed (those that are marked
+    #     as perturbed). We should avoid computing intermediate cotangents that aren't
+    #     needed.
+    # (2) Some cotangents may be symbolic zeros. We should avoid materialising any that
+    #     don't need to be materialised.
+    # (3) Some parts of the body function may not be differentiable. (E.g. they call
+    #     `eqxi.nondifferentiable` or `eqxi.nondifferentiable_backward`.) In particular,
+    #     some inexact arrays may be marked with these. (Anything else is automatically
+    #     nondifferentiable.) As these raise an error eagerly, we can't just take an
+    #     approach of "differentiate everything, then DCE".
+    #
+    # Our strategy is as follows. First, consider those inputs that are perturbed, and
+    # iterate to find all the carries that become perturbed as a result. We obtain a
+    # fixed point.
+    # * This mimics what JAX does in JVP rules for scans/whiles. And in fact we also
+    #   track perturbation information forward via JVPs (which input tangents affect
+    #   which output tangents), rather than backward via VJPs (which output cotangents
+    #   affect which input cotangents), as this is what JAX does here and we'd like to
+    #   be consistent. (Also, it seems like there might be a catch-22 trying to do this
+    #   with VJPs -- how do you decide which inputs to perturb in your VJP?) This is
+    #   basically handling criterion (1).
+    # * This also means that we're perturbing the minimal amount possible. We don't even
+    #   run the JVP rule for something that isn't differentiated. This is part of
+    #   addressing criterion (3).
+    #
+    # Second, we ignore all cotangents (of our forward output) that aren't on a
+    # perturbed carry. Clearly they're not going to affect anything. So we switch them
+    # to symbolic zeros.
+    # * This is part of addressing criterion (3) -- we avoid propagating cotangents
+    #   that we don't need to. Maybe there's an `eqxi.nondifferentiable_backward` in
+    #   there.
+    #
+    # Third, we find another fixed point, this time for the symbolic zeros. Track which
+    # cotangents affect which cotangents, and materialise the minimal amount possible.
+    # * This addresses criterion (2).
+    # * Note that symbolic zeros are doing double-duty here. We avoid materialising them
+    #   for simple efficiency reasons, sure -- but we also want to keep as many
+    #   unmaterialised as possible, in case there's an `eqxi.nondifferentiable_backward`
+    #   in there, which only accepts symbolic zero cotangents.
+    #
+    # Phew! At that point, job done: we know which cotangents we're propagating and so
+    # we can run our loop.
+    #
+    # The first fixed point basically mimics what JAX does in the JVP rule for a while
+    # or scan, figuring out which values are perturbed. The second fixed point basically
+    # mimics what JAX does in the transpose rule of a scan, figuring out which
+    # cotangents are symbolic zeros.
+    #
+
+    # Do this inside the dynamic context of a `jax.eval_shape` so that it's as cheap as
+    # possible, without extra tracing.
+    def _resolve_perturb_val():
+        init_val_buffers = tree_at(
+            buffers(_is_none), init_val, filled_buffers, is_leaf=_is_none
+        )
+        perturb_val = perturb_init_val
+        assert jtu.tree_structure(init_val_buffers) == jtu.tree_structure(perturb_val)
+
+        while True:
+            # `body_fun` is included so that we also track perturbations on that.
+            perturb_pair = (perturb_body_fun, perturb_val)
+            dynamic, static = partition((body_fun, init_val_buffers), perturb_pair)
+            new_perturb_val = sentinel = object()
+
+            @jax.custom_jvp
+            def _record_symbolic_zeros(_dynamic_out):
+                return _dynamic_out
+
+            def _record_symbolic_zeros_jvp(primals, tangents):
+                (primals,) = primals
+                (tangents,) = tangents
+                nonlocal new_perturb_val
+                new_perturb_val = jtu.tree_map(_not_symbolic_zero, tangents)
+                return primals, tangents
+
+            _record_symbolic_zeros.defjvp(
+                _record_symbolic_zeros_jvp, symbolic_zeros=True
+            )
+
+            def _to_linearize(_dynamic):
+                _body_fun, _val = combine(_dynamic, static)
+                _out = _body_fun(_val)
+                _dynamic_out, _static_out = partition(_out, is_inexact_array)
+                _dynamic_out = _record_symbolic_zeros(_dynamic_out)
+                _out = combine(_dynamic_out, _static_out)
+                return _out
+
+            # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
+            jax.linearize(_to_linearize, dynamic)
+            assert new_perturb_val is not sentinel
+            assert jtu.tree_structure(
+                perturb_val, is_leaf=_is_none
+            ) == jtu.tree_structure(new_perturb_val, is_leaf=_is_none)
+            new_perturb_val = jtu.tree_map(
+                lambda x, y: False if (x is False and y is None) else y,
+                perturb_val,
+                new_perturb_val,
+            )
+            assert jtu.tree_structure(perturb_val) == jtu.tree_structure(
+                new_perturb_val
+            )
+            if tree_equal(perturb_val, new_perturb_val):
+                jtu.tree_map(_assert_bool, perturb_val)
+                if getattr(typing, "TESTING", False):
+                    print("perturb_val", perturb_val)
+                return Static(perturb_val)
+            else:
+                perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
+
+    # Find a fixed point of which values we need to perturb. (Not the same as whether
+    # or not they have symbolic zero cotangents! All unperturbed values must have
+    # symbolic zero cotangents, but some perturbed values may have these as well.)
+    perturb_val = jax.eval_shape(_resolve_perturb_val).value
+
+    # Do this inside the dynamic context of a `jax.eval_shape` so that it's as cheap as
     # possible, without extra tracing.
     def _resolve_symbolic_zeros(grad_val):
         symbolic_zero_gradient = jtu.tree_map(_is_none, grad_val, is_leaf=_is_none)
-        new_symbolic_zero_gradient = None
         init_val_buffers = tree_at(
             buffers(_is_none), init_val, filled_buffers, is_leaf=_is_none
         )
-        dynamic_init_val, static_init_val = partition(
-            init_val_buffers, is_inexact_array
-        )
+        dynamic_init_val, static_init_val = partition(init_val_buffers, perturb_val)
 
-        # Some hackery to extract which inputs still have symbolic zero cotangents.
-        @jax.custom_vjp
-        def _record_symbolic_zero(_dynamic_val):
-            return _dynamic_val
+        while True:
+            new_symbolic_zero_gradient = sentinel = object()
 
-        def _record_symbolic_zero_fwd(_dynamic_val):
-            return jtu.tree_map(_get_value, _dynamic_val), None
+            # Some hackery to extract which inputs still have symbolic zero cotangents.
+            @jax.custom_vjp
+            def _record_symbolic_zero(_dynamic_val):
+                return _dynamic_val
 
-        def _record_symbolic_zero_bwd(_, grad_dynamic_val):
-            nonlocal new_symbolic_zero_gradient
-            new_symbolic_zero_gradient = jtu.tree_map(
-                _is_symbolic_zero_or_none, grad_dynamic_val, is_leaf=_is_none
-            )
-            return (grad_dynamic_val,)
+            def _record_symbolic_zero_fwd(_dynamic_val):
+                return jtu.tree_map(_get_value, _dynamic_val), None
 
-        _record_symbolic_zero.defvjp(
-            _record_symbolic_zero_fwd, _record_symbolic_zero_bwd, symbolic_zeros=True
-        )
+            def _record_symbolic_zero_bwd(_, grad_dynamic_val):
+                nonlocal new_symbolic_zero_gradient
+                new_symbolic_zero_gradient = jtu.tree_map(
+                    _is_symbolic_zero_or_none, grad_dynamic_val, is_leaf=_is_none
+                )
+                return (grad_dynamic_val,)
+
+            _record_symbolic_zero.defvjp(
+                _record_symbolic_zero_fwd,
+                _record_symbolic_zero_bwd,
+                symbolic_zeros=True,
+            )
 
-        def _to_vjp(_dynamic_val):
-            _dynamic_val = _record_symbolic_zero(_dynamic_val)
-            val = combine(_dynamic_val, static_init_val)
-            return filter(body_fun(val), symbolic_zero_gradient, inverse=True)
+            def _to_vjp(_dynamic_val):
+                _dynamic_val = _record_symbolic_zero(_dynamic_val)
+                val = combine(_dynamic_val, static_init_val)
+                return filter(body_fun(val), symbolic_zero_gradient, inverse=True)
 
-        while True:
             _, vjp_fn = jax.vjp(_to_vjp, dynamic_init_val)
-            vjp_fn(grad_val)  # get new_symbolic_zero_gradient via nonlocal
+            vjp_fn(grad_val)
+            assert new_symbolic_zero_gradient is not sentinel
             if tree_equal(symbolic_zero_gradient, new_symbolic_zero_gradient):
                 jtu.tree_map(_assert_bool, symbolic_zero_gradient)
                 if getattr(typing, "TESTING", False):
@@ -971,6 +1095,12 @@ def _to_vjp(_dynamic_val):
         init_val_buffers,
     )
 
+    grad_final_val = filter(grad_final_val, perturb_val)
+    # If all provided cotangents end up being irrelevant -- i.e. the perturbed inputs
+    # depend only on those outputs which have symbolic zero cotangents -- then we can
+    # skip the whole computation.
+    if len(jtu.tree_leaves(grad_final_val)) == 0:
+        return jtu.tree_map(lambda _: None, (grad_final_val, body_fun))
     symbolic_zero_gradient = jax.eval_shape(
         _resolve_symbolic_zeros, grad_final_val
     ).value
diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py
index d3a21b20..0f642342 100644
--- a/tests/test_while_loop.py
+++ b/tests/test_while_loop.py
@@ -726,3 +726,127 @@ def buffers(carry):
     eqxi.while_loop(
         cond_fun, body_fun, init, kind="checkpointed", buffers=buffers, max_steps=2
     )
+
+
+# This test includes complexities like buffers, exact arrays, and `None`, just to be
+# sure we handle these complexities here too.
+def test_nondifferentiable_body1(): + def cond_fun(carry): + return True + + def body_fun(carry): + step, x, y, z, _ = carry + y2 = eqxi.nondifferentiable(y) + return step + 1, x + y2, y + 1, z.at[step].set(y), None + + @eqx.filter_jit + @jax.value_and_grad + def run(x__z, y_in, true): + x_in, z_in = x__z + init = (0, x_in, y_in, z_in, None) + if true: + out = _while_as_scan(cond_fun, body_fun, init, max_steps=3) + else: + out = eqxi.while_loop( + cond_fun, body_fun, init, max_steps=3, kind="checkpointed" + ) + _, x_out, y_out, z_out, none = out + assert none is None + return x_out + y_out + jnp.sum(z_out) + + x_in = jnp.array(1.2) + y_in = jnp.array(0.7) + z_in = jnp.array([-5.0, -5.0, -5.0]) + true = run((x_in, z_in), y_in, true=True) + outs = run((x_in, z_in), y_in, true=False) + assert shaped_allclose(true, outs) + + +def test_nondifferentiable_body2(capfd): + def cond_fun(carry): + return True + + # This function is set up so that (x, y, z) require multiple passes through to + # propagate which values are perturbed. + # This function is set up so that w has a cotangent that should be dropped. + def body_fun(carry): + x, y, z, w = carry + w = eqxi.nondifferentiable(w) + return x + 1, x + y, y + z, w * 2 + + @jax.jit + @jax.grad + def run(x, y, z, w): + x, y, z, w = eqxi.while_loop( + cond_fun, body_fun, (x, y, z, w), max_steps=3, kind="checkpointed" + ) + return y + w + + capfd.readouterr() + run(1.0, 1.0, 1.0, 1.0) + text, _ = capfd.readouterr() + assert "perturb_val (False, False, False, (True, True, True, False))" in text + assert ( + "symbolic_zero_gradient (True, True, True, (False, False, True, True))" in text + ) + + +def test_body_fun_grads(capfd): + def cond_fun(carry): + return True + + @eqx.filter_jit + @jax.grad + def run(x__y, true): + x, y = x__y + # `init` and `body_fun` are deliberately chosen so that `carry[0]` requires a + # gradient solely for the purpose of propagating that gradient back into `x`. + # (And in particular, not for propagating it back into `init`.) + # Thus this test is checking that we get gradients with respect to `body_fun` + # correctly. + init = (1.0, y) + + def body_fun(carry): + a, b = carry + return a * x, b + 1 + + if true: + final = _while_as_scan(cond_fun, body_fun, init, max_steps=3) + else: + final = eqxi.while_loop( + cond_fun, body_fun, init, max_steps=3, kind="checkpointed" + ) + return sum(final) + + x__y = (jnp.array(1.0), jnp.array(1.0)) + + capfd.readouterr() + outs = run(x__y, true=False) + text, _ = capfd.readouterr() + assert "perturb_val (False, False, False, (True, True))" in text + assert "symbolic_zero_gradient (True, True, True, (False, False))" in text + + true = run(x__y, true=True) + assert shaped_allclose(true, outs) + + +def test_trivial_vjp(capfd): + def cond_fun(carry): + return True + + def body_fun(carry): + return carry + + @jax.jit + @jax.grad + def run(x): + a, b = eqxi.while_loop( + cond_fun, body_fun, (x, 0.0), max_steps=3, kind="checkpointed" + ) + return b + + capfd.readouterr() + assert run(1.0) == 0 + text, _ = capfd.readouterr() + assert "perturb_val (False, False, False, (True, False))" in text + assert "symbolic_zero_gradient" not in text From 5c4a8423331461a9afde33695bbbb8fee3e090a1 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 8 Oct 2023 23:15:01 -0700 Subject: [PATCH 3/5] Drive-by: fix some eqx.debug.announce_transform crashes. 
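The first of the two crashes fixed below is a classic Python gotcha, worth spelling out with a standalone sketch (not part of the patch): parentheses alone don't make a tuple -- the trailing comma does -- so `("impl,")` is just the string `"impl,"`, and concatenating it to a tuple raises a `TypeError`.

```python
# The trailing comma, not the parentheses, makes a one-element tuple.
assert ("impl",) == tuple(["impl"])  # a one-element tuple
assert ("impl,") == "impl,"          # just a parenthesised string

stack = ("outer",)
stack = stack + ("impl",)  # OK: tuple + tuple
try:
    stack + ("impl,")  # str concatenated to a tuple
except TypeError as e:
    print(e)  # can only concatenate tuple (not "str") to tuple
```

The second crash comes from binding symbolic-zero tangents (`ad.Zero`) to the announcement primitive; the fix filters them out before the `bind` and splices them back into the outputs afterwards.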
--- equinox/debug/_announce_transform.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/equinox/debug/_announce_transform.py b/equinox/debug/_announce_transform.py index 7403a5eb..9e6af450 100644 --- a/equinox/debug/_announce_transform.py +++ b/equinox/debug/_announce_transform.py @@ -60,7 +60,7 @@ def announce_transform( def _impl(*x, stack, name, intermediates, announce): del intermediates - stack = stack + ("impl,") + stack = stack + ("impl",) announce(name + ":".join(stack)) return x @@ -81,9 +81,20 @@ def _jvp(p, t, *, stack, name, intermediates, announce): p_out = announce_jaxpr_p.bind( *p, stack=p_stack, name=name, intermediates=intermediates, announce=announce ) - t_out = announce_jaxpr_p.bind( - *t, stack=t_stack, name=name, intermediates=intermediates, announce=announce - ) + t_nonzero = [ti for ti in t if type(ti) is not ad.Zero] + if len(t_nonzero) > 0: + t_nonzero_out = announce_jaxpr_p.bind( + *t_nonzero, + stack=t_stack, + name=name, + intermediates=intermediates, + announce=announce, + ) + else: + t_nonzero_out = [] + t_nonzero_out = iter(t_nonzero_out) + t_out = [ti if type(ti) is ad.Zero else next(t_nonzero_out) for ti in t] + assert next(t_nonzero_out, None) is None return p_out, t_out From e1f6e229eed628dc84badc9115f6d29d5d1389d7 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 9 Oct 2023 12:00:16 -0700 Subject: [PATCH 4/5] CWL now handles disable_jit correctly. --- equinox/internal/_loop/common.py | 20 +++++++++----------- tests/test_while_loop.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/equinox/internal/_loop/common.py b/equinox/internal/_loop/common.py index f732d52b..69e892a5 100644 --- a/equinox/internal/_loop/common.py +++ b/equinox/internal/_loop/common.py @@ -322,17 +322,11 @@ def _unwrap_buffers(x): return x -# Work around JAX issue #15676 -@jax.custom_jvp -def fixed_asarray(x): - return jnp.asarray(x) - - -@fixed_asarray.defjvp -def _fixed_asarray_jvp(x, tx): - (x,) = x - (tx,) = tx - return fixed_asarray(x), fixed_asarray(tx) +# Work around JAX issue #15676. +# This issue arises with both JVP tracing and make_jaxpr tracing. The former can be +# handled with a custom_jvp, but the latter cannot. So we need to just call `jnp.array` +# instead. +fixed_asarray = jnp.array def common_rewrite(cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps): @@ -417,6 +411,10 @@ def unwrap_and_select(leaf, leaf2): step, pred, _, val = val buffer_val = _wrap_buffers(val, pred, tag) buffer_val2 = body_fun(buffer_val) + # Needed to work with `disable_jit`, as then we lose the automatic + # ArrayLike->Array cast provided by JAX's while loops. + # The input `val` is already cast to Array below, so this matches that. + buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2) # Strip `.named_shape`; c.f. 
Diffrax issue #246 struct = jax.eval_shape(lambda: buffer_val) struct2 = jax.eval_shape(lambda: buffer_val2) diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py index 0f642342..b8fd985d 100644 --- a/tests/test_while_loop.py +++ b/tests/test_while_loop.py @@ -712,6 +712,17 @@ def run(init_carry): ) in text +def test_disable_jit(): + def cond_fun(carry): + return True + + def body_fun(carry): + return 5 + + with jax.disable_jit(): + eqxi.while_loop(cond_fun, body_fun, 3, max_steps=3, kind="checkpointed") + + def test_buffer_index(): def cond_fun(carry): return True From ad9b5a31063c9ce10ca74be1b79557f16e6d326f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 9 Oct 2023 12:00:37 -0700 Subject: [PATCH 5/5] Drive-by: fix wrong type annotation (now caught with jaxtyping update) --- equinox/nn/_normalisation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/equinox/nn/_normalisation.py b/equinox/nn/_normalisation.py index 88865e24..589d5160 100644 --- a/equinox/nn/_normalisation.py +++ b/equinox/nn/_normalisation.py @@ -50,7 +50,7 @@ class LayerNorm(Module): """ - shape: tuple[int] = field(static=True) + shape: tuple[int, ...] = field(static=True) eps: float = field(static=True) use_weight: bool = field(static=True) use_bias: bool = field(static=True)
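For reference, the distinction behind this last fix: `tuple[int]` is a tuple of exactly one `int`, while `tuple[int, ...]` is a homogeneous tuple of any length, so a multi-dimensional `shape` only satisfies the latter. A standalone sketch (not part of the patch):

```python
from typing import get_args

# `tuple[int]` is length-one; `tuple[int, ...]` is variadic.
assert get_args(tuple[int]) == (int,)
assert get_args(tuple[int, ...]) == (int, Ellipsis)

shape_one: tuple[int] = (64,)  # OK: exactly one int
shape_any: tuple[int, ...] = (64, 128)  # OK: any number of ints
# A runtime checker (beartype, via jaxtyping) rejects (64, 128) annotated as
# tuple[int] -- which is the error the corrected annotation avoids.
```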